Unverified Commit fb850af7 authored by guoshzhao's avatar guoshzhao Committed by GitHub
Browse files

Benchmarks: Add Feature - Add interface to get all predefined parameters of all benchmarks. (#56)

* Benchmarks: Add Feature - Add interface to get all predefined parameters of all benchmarks.
parent 435b2d5e
...@@ -23,7 +23,7 @@ def __init__(self, name, parameters=''): ...@@ -23,7 +23,7 @@ def __init__(self, name, parameters=''):
parameters (str): benchmark parameters. parameters (str): benchmark parameters.
""" """
self._name = name self._name = name
self._argv = list(filter(None, parameters.split(' '))) self._argv = list(filter(None, parameters.split(' '))) if parameters is not None else list()
self._benchmark_type = None self._benchmark_type = None
self._parser = argparse.ArgumentParser( self._parser = argparse.ArgumentParser(
add_help=False, add_help=False,
......
...@@ -23,7 +23,7 @@ class BenchmarkRegistry: ...@@ -23,7 +23,7 @@ class BenchmarkRegistry:
benchmarks: Dict[str, dict] = dict() benchmarks: Dict[str, dict] = dict()
@classmethod @classmethod
def register_benchmark(cls, name, class_def, parameters=None, platform=None): def register_benchmark(cls, name, class_def, parameters='', platform=None):
"""Register new benchmark, key is the benchmark name. """Register new benchmark, key is the benchmark name.
Args: Args:
...@@ -67,6 +67,18 @@ def register_benchmark(cls, name, class_def, parameters=None, platform=None): ...@@ -67,6 +67,18 @@ def register_benchmark(cls, name, class_def, parameters=None, platform=None):
cls.benchmarks[name][p] = (class_def, parameters) cls.benchmarks[name][p] = (class_def, parameters)
benchmark = class_def(name, parameters)
benchmark.add_parser_arguments()
ret, args, unknown = benchmark.parse_args()
if not ret or len(unknown) >= 1:
logger.log_and_raise(
TypeError,
'Registered benchmark has invalid arguments - benchmark: {}, parameters: {}'.format(name, parameters)
)
else:
cls.benchmarks[name]['predefine_param'] = vars(args)
logger.info('Benchmark registration - benchmark: {}, predefine_parameters: {}'.format(name, vars(args)))
@classmethod @classmethod
def is_benchmark_context_valid(cls, benchmark_context): def is_benchmark_context_valid(cls, benchmark_context):
"""Check wether the benchmark context is valid or not. """Check wether the benchmark context is valid or not.
...@@ -143,6 +155,19 @@ def get_benchmark_configurable_settings(cls, benchmark_context): ...@@ -143,6 +155,19 @@ def get_benchmark_configurable_settings(cls, benchmark_context):
else: else:
return None return None
@classmethod
def get_all_benchmark_predefine_settings(cls):
    """Collect the predefined (default) settings of every registered benchmark.

    Return:
        benchmark_params (dict[str, dict]): key is benchmark name,
          value is the dict with structure: {'parameter': default_value}.
    """
    return {name: info['predefine_param'] for name, info in cls.benchmarks.items()}
@classmethod @classmethod
def launch_benchmark(cls, benchmark_context): def launch_benchmark(cls, benchmark_context):
"""Select and Launch benchmark. """Select and Launch benchmark.
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
from superbench.benchmarks import Platform, Framework, BenchmarkType, BenchmarkRegistry, ReturnCode from superbench.benchmarks import Platform, Framework, BenchmarkType, BenchmarkRegistry, ReturnCode
from superbench.benchmarks.micro_benchmarks import MicroBenchmark from superbench.benchmarks.micro_benchmarks import MicroBenchmark
from superbench.benchmarks.micro_benchmarks.sharding_matmul import ShardingMode
class AccumulationBenchmark(MicroBenchmark): class AccumulationBenchmark(MicroBenchmark):
...@@ -196,3 +197,20 @@ def test_launch_benchmark(): ...@@ -196,3 +197,20 @@ def test_launch_benchmark():
benchmark = BenchmarkRegistry.launch_benchmark(context) benchmark = BenchmarkRegistry.launch_benchmark(context)
assert (benchmark) assert (benchmark)
assert (benchmark.return_code == ReturnCode.INVALID_ARGUMENT) assert (benchmark.return_code == ReturnCode.INVALID_ARGUMENT)
def test_get_all_benchmark_predefine_settings():
    """Test interface BenchmarkRegistry.get_all_benchmark_predefine_settings()."""
    all_params = BenchmarkRegistry.get_all_benchmark_predefine_settings()
    # Choose benchmark 'pytorch-sharding-matmul' for testing.
    benchmark_name = 'pytorch-sharding-matmul'
    assert (benchmark_name in all_params)
    # Table of the benchmark's expected default parameter values.
    expected_defaults = {
        'run_count': 1,
        'duration': 0,
        'n': 4096,
        'k': 4096,
        'm': 4096,
        'mode': [ShardingMode.ALLREDUCE, ShardingMode.ALLGATHER],
        'num_warmup': 10,
        'num_steps': 500,
    }
    for param, default_value in expected_defaults.items():
        assert (all_params[benchmark_name][param] == default_value)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment