Unverified Commit 31b6f085 authored by guoshzhao's avatar guoshzhao Committed by GitHub
Browse files

Benchmarks: Code Revision - Support benchmark re-registration, keep the latest one. (#23)



* support benchmark re-registration.

* address comments
Co-authored-by: default avatarGuoshuai Zhao <guzhao@microsoft.com>
parent 5b9b5cc8
......@@ -6,7 +6,6 @@
from typing import Dict
from superbench.common.utils import logger
from superbench.common.errors import DuplicateBenchmarkRegistrationError
from superbench.benchmarks import Platform, Framework, BenchmarkContext
from superbench.benchmarks.base import Benchmark
......@@ -56,24 +55,17 @@ def register_benchmark(cls, name, class_def, parameters=None, platform=None):
name, platform_list, platform
)
)
if platform in cls.benchmarks[name]:
logger.warning('Duplicate registration - benchmark: {}, platform: {}'.format(name, platform))
if platform not in cls.benchmarks[name]:
cls.benchmarks[name][platform] = (class_def, parameters)
else:
logger.log_and_raise(
DuplicateBenchmarkRegistrationError,
'Duplicate registration - benchmark: {}, platform: {}'.format(name, platform)
)
cls.benchmarks[name][platform] = (class_def, parameters)
else:
# If not specified the tag, means the
# benchmark works for all platforms.
# If not specified the tag, means the benchmark works for all platforms.
for p in Platform:
if p not in cls.benchmarks[name]:
cls.benchmarks[name][p] = (class_def, parameters)
else:
logger.log_and_raise(
DuplicateBenchmarkRegistrationError, 'Duplicate registration - benchmark: {}'.format(name)
)
if p in cls.benchmarks[name]:
logger.warning('Duplicate registration - benchmark: {}, platform: {}'.format(name, p))
cls.benchmarks[name][p] = (class_def, parameters)
@classmethod
def is_benchmark_context_valid(cls, benchmark_context):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Exception types for SuperBench errors."""
class DuplicateBenchmarkRegistrationError(Exception):
"""An error is raised for duplicate benchmark registration."""
pass
......@@ -116,7 +116,6 @@ def create_benchmark(params='--num_steps=8'):
assert (name)
(benchmark_class, predefine_params) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(name, context.platform)
assert (benchmark_class)
BenchmarkRegistry.clean_benchmarks()
return benchmark_class(name, predefine_params + ' ' + context.parameters)
......
......@@ -63,8 +63,6 @@ def test_register_benchmark():
context = BenchmarkContext('accumulation', platform)
assert (BenchmarkRegistry.is_benchmark_registered(context))
BenchmarkRegistry.clean_benchmarks()
# Register the benchmark for CUDA platform if use platform=Platform.CUDA.
BenchmarkRegistry.register_benchmark('accumulation-cuda', AccumulationBenchmark, platform=Platform.CUDA)
context = BenchmarkContext('accumulation-cuda', Platform.CUDA)
......@@ -72,8 +70,6 @@ def test_register_benchmark():
context = BenchmarkContext('accumulation-cuda', Platform.ROCM)
assert (BenchmarkRegistry.is_benchmark_registered(context) is False)
BenchmarkRegistry.clean_benchmarks()
def test_is_benchmark_context_valid():
"""Test interface BenchmarkRegistry.is_benchmark_context_valid()."""
......@@ -102,8 +98,6 @@ def test_get_benchmark_name():
name = BenchmarkRegistry._BenchmarkRegistry__get_benchmark_name(context)
assert (name == benchmark_names[i])
BenchmarkRegistry.clean_benchmarks()
def test_check_parameters():
"""Test interface BenchmarkRegistry.check_parameters()."""
......@@ -118,8 +112,6 @@ def test_check_parameters():
context = BenchmarkContext('accumulation', Platform.CPU, parameters='--lower=1')
assert (BenchmarkRegistry.check_parameters(context) is False)
BenchmarkRegistry.clean_benchmarks()
def test_get_benchmark_configurable_settings():
"""Test BenchmarkRegistry interface.
......@@ -139,8 +131,6 @@ def test_get_benchmark_configurable_settings():
--upper_bound int The upper bound for accumulation."""
assert (settings == expected)
BenchmarkRegistry.clean_benchmarks()
def test_launch_benchmark():
"""Test interface BenchmarkRegistry.launch_benchmark()."""
......@@ -194,10 +184,8 @@ def test_launch_benchmark():
)
assert (result == expected)
# Failed to launch benchmark.
# Failed to launch benchmark due to 'benchmark not found'.
context = BenchmarkContext(
'accumulation', Platform.CPU, parameters='--lower_bound=1 --upper_bound=4', framework=Framework.PYTORCH
'accumulation-fail', Platform.CPU, parameters='--lower_bound=1 --upper_bound=4', framework=Framework.PYTORCH
)
assert (BenchmarkRegistry.check_parameters(context) is False)
BenchmarkRegistry.clean_benchmarks()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment