Unverified Commit fb850af7 authored by guoshzhao's avatar guoshzhao Committed by GitHub
Browse files

Benchmarks: Add Feature - Add interface to get all predefine parameters of all benchmarks. (#56)

* Benchmarks: Add Feature - Add interface to get all predefine parameters of all benchmarks.
parent 435b2d5e
......@@ -23,7 +23,7 @@ def __init__(self, name, parameters=''):
parameters (str): benchmark parameters.
"""
self._name = name
self._argv = list(filter(None, parameters.split(' ')))
self._argv = list(filter(None, parameters.split(' '))) if parameters is not None else list()
self._benchmark_type = None
self._parser = argparse.ArgumentParser(
add_help=False,
......
......@@ -23,7 +23,7 @@ class BenchmarkRegistry:
benchmarks: Dict[str, dict] = dict()
@classmethod
def register_benchmark(cls, name, class_def, parameters=None, platform=None):
def register_benchmark(cls, name, class_def, parameters='', platform=None):
"""Register new benchmark, key is the benchmark name.
Args:
......@@ -67,6 +67,18 @@ def register_benchmark(cls, name, class_def, parameters=None, platform=None):
cls.benchmarks[name][p] = (class_def, parameters)
benchmark = class_def(name, parameters)
benchmark.add_parser_arguments()
ret, args, unknown = benchmark.parse_args()
if not ret or len(unknown) >= 1:
logger.log_and_raise(
TypeError,
'Registered benchmark has invalid arguments - benchmark: {}, parameters: {}'.format(name, parameters)
)
else:
cls.benchmarks[name]['predefine_param'] = vars(args)
logger.info('Benchmark registration - benchmark: {}, predefine_parameters: {}'.format(name, vars(args)))
@classmethod
def is_benchmark_context_valid(cls, benchmark_context):
"""Check wether the benchmark context is valid or not.
......@@ -143,6 +155,19 @@ def get_benchmark_configurable_settings(cls, benchmark_context):
else:
return None
@classmethod
def get_all_benchmark_predefine_settings(cls):
"""Get all registered benchmarks' predefine settings.
Return:
benchmark_params (dict[str, dict]): key is benchmark name,
value is the dict with structure: {'parameter': default_value}.
"""
benchmark_params = dict()
for name in cls.benchmarks:
benchmark_params[name] = cls.benchmarks[name]['predefine_param']
return benchmark_params
@classmethod
def launch_benchmark(cls, benchmark_context):
"""Select and Launch benchmark.
......
......@@ -7,6 +7,7 @@
from superbench.benchmarks import Platform, Framework, BenchmarkType, BenchmarkRegistry, ReturnCode
from superbench.benchmarks.micro_benchmarks import MicroBenchmark
from superbench.benchmarks.micro_benchmarks.sharding_matmul import ShardingMode
class AccumulationBenchmark(MicroBenchmark):
......@@ -196,3 +197,20 @@ def test_launch_benchmark():
benchmark = BenchmarkRegistry.launch_benchmark(context)
assert (benchmark)
assert (benchmark.return_code == ReturnCode.INVALID_ARGUMENT)
def test_get_all_benchmark_predefine_settings():
"""Test interface BenchmarkRegistry.get_all_benchmark_predefine_settings()."""
benchmark_params = BenchmarkRegistry.get_all_benchmark_predefine_settings()
# Choose benchmark 'pytorch-sharding-matmul' for testing.
benchmark_name = 'pytorch-sharding-matmul'
assert (benchmark_name in benchmark_params)
assert (benchmark_params[benchmark_name]['run_count'] == 1)
assert (benchmark_params[benchmark_name]['duration'] == 0)
assert (benchmark_params[benchmark_name]['n'] == 4096)
assert (benchmark_params[benchmark_name]['k'] == 4096)
assert (benchmark_params[benchmark_name]['m'] == 4096)
assert (benchmark_params[benchmark_name]['mode'] == [ShardingMode.ALLREDUCE, ShardingMode.ALLGATHER])
assert (benchmark_params[benchmark_name]['num_warmup'] == 10)
assert (benchmark_params[benchmark_name]['num_steps'] == 500)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment