Unverified Commit 8d24d03d authored by guoshzhao's avatar guoshzhao Committed by GitHub
Browse files

Benchmarks: Code Revision - Move benchmarks auto-registration from registry.py to __init__.py (#24)



* move benchmarks registration from registry.py to __init__.py

* revise __init__.
Co-authored-by: default avatarGuoshuai Zhao <guzhao@microsoft.com>
parent 5dfcc6be
...@@ -3,9 +3,10 @@ ...@@ -3,9 +3,10 @@
"""Exposes interfaces of benchmarks used by SuperBench executor.""" """Exposes interfaces of benchmarks used by SuperBench executor."""
from .return_code import ReturnCode from superbench.benchmarks.return_code import ReturnCode
from .context import Platform, Framework, Precision, ModelAction, BenchmarkType, BenchmarkContext from superbench.benchmarks.context import Platform, Framework, Precision, ModelAction, BenchmarkType, BenchmarkContext
from .registry import BenchmarkRegistry from superbench.benchmarks.registry import BenchmarkRegistry
from superbench.benchmarks import model_benchmarks, micro_benchmarks, docker_benchmarks # noqa pylint: disable=unused-import
__all__ = [ __all__ = [
'ReturnCode', 'Platform', 'Framework', 'BenchmarkType', 'Precision', 'ModelAction', 'BenchmarkContext', 'ReturnCode', 'Platform', 'Framework', 'BenchmarkType', 'Precision', 'ModelAction', 'BenchmarkContext',
......
...@@ -3,6 +3,6 @@ ...@@ -3,6 +3,6 @@
"""A module containing all the benchmarks packaged in docker.""" """A module containing all the benchmarks packaged in docker."""
from .docker_base import DockerBenchmark from superbench.benchmarks.docker_benchmarks.docker_base import DockerBenchmark
__all__ = ['DockerBenchmark'] __all__ = ['DockerBenchmark']
...@@ -3,6 +3,6 @@ ...@@ -3,6 +3,6 @@
"""A module containing all the micro-benchmarks.""" """A module containing all the micro-benchmarks."""
from .micro_base import MicroBenchmark from superbench.benchmarks.micro_benchmarks.micro_base import MicroBenchmark
__all__ = ['MicroBenchmark'] __all__ = ['MicroBenchmark']
...@@ -3,6 +3,6 @@ ...@@ -3,6 +3,6 @@
"""A module containing all the e2e model related benchmarks.""" """A module containing all the e2e model related benchmarks."""
from .model_base import ModelBenchmark from superbench.benchmarks.model_benchmarks.model_base import ModelBenchmark
__all__ = ['ModelBenchmark'] __all__ = ['ModelBenchmark']
...@@ -8,9 +8,6 @@ ...@@ -8,9 +8,6 @@
from superbench.common.utils import logger from superbench.common.utils import logger
from superbench.benchmarks import Platform, Framework, BenchmarkContext from superbench.benchmarks import Platform, Framework, BenchmarkContext
from superbench.benchmarks.base import Benchmark from superbench.benchmarks.base import Benchmark
import superbench.benchmarks.model_benchmarks # noqa pylint: disable=unused-import
import superbench.benchmarks.micro_benchmarks # noqa pylint: disable=unused-import
import superbench.benchmarks.docker_benchmarks # noqa pylint: disable=unused-import
class BenchmarkRegistry: class BenchmarkRegistry:
......
...@@ -233,5 +233,3 @@ def test_pytorch_base(): ...@@ -233,5 +233,3 @@ def test_pytorch_base():
assert (isinstance(benchmark._optimizer, torch.optim.SGD)) assert (isinstance(benchmark._optimizer, torch.optim.SGD))
benchmark._optimizer_type = None benchmark._optimizer_type = None
assert (benchmark._create_optimizer() is False) assert (benchmark._create_optimizer() is False)
BenchmarkRegistry.clean_benchmarks()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment