Unverified Commit c00dc670 authored by guoshzhao's avatar guoshzhao Committed by GitHub
Browse files

Benchmarks: Add Feature - Add sample_count argument for ModelBenchmark. (#22)



* add sample_count argument.

* handle more conditions.

* address comments.
Co-authored-by: default avatarGuoshuai Zhao <guzhao@microsoft.com>
parent 31b6f085
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
"""Module of the model-benchmark base class.""" """Module of the model-benchmark base class."""
import math
from abc import abstractmethod from abc import abstractmethod
from superbench.common.utils import logger from superbench.common.utils import logger
...@@ -76,6 +77,13 @@ def add_parser_arguments(self): ...@@ -76,6 +77,13 @@ def add_parser_arguments(self):
required=False, required=False,
help='The number of test step.', help='The number of test step.',
) )
self._parser.add_argument(
'--sample_count',
type=int,
default=128,
required=False,
help='The number of data samples in dataset.',
)
self._parser.add_argument( self._parser.add_argument(
'--batch_size', '--batch_size',
type=int, type=int,
...@@ -170,6 +178,9 @@ def _preprocess(self): ...@@ -170,6 +178,9 @@ def _preprocess(self):
self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE) self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE)
return False return False
# Set sample_count aligned with batch_size.
self._args.sample_count = math.ceil(self._args.sample_count / self._args.batch_size) * self._args.batch_size
if not self._generate_dataset(): if not self._generate_dataset():
self._result.set_return_code(ReturnCode.DATASET_GENERATION_FAILURE) self._result.set_return_code(ReturnCode.DATASET_GENERATION_FAILURE)
return False return False
......
...@@ -22,12 +22,18 @@ def __init__(self, shape, world_size, dtype=torch.float): ...@@ -22,12 +22,18 @@ def __init__(self, shape, world_size, dtype=torch.float):
self._len = 0 self._len = 0
self._data = None self._data = None
if dtype in [torch.float32, torch.float64]: try:
self._data = torch.randn(*shape, dtype=dtype) if dtype in [torch.float32, torch.float64]:
elif dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: self._data = torch.randn(*shape, dtype=dtype)
self._data = torch.randint(0, 128, tuple(shape), dtype=dtype) elif dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
else: self._data = torch.randint(0, 128, tuple(shape), dtype=dtype)
logger.error('Unsupported precision for RandomDataset - data type: {}.'.format(dtype)) else:
logger.error('Unsupported precision for RandomDataset - data type: {}.'.format(dtype))
return
except BaseException as e:
logger.error(
'Generate random dataset failed - data type: {}, shape: {}, message: {}.'.format(dtype, shape, str(e))
)
return return
self._len = shape[0] * world_size self._len = shape[0] * world_size
......
...@@ -146,6 +146,7 @@ def test_arguments_related_interfaces(): ...@@ -146,6 +146,7 @@ def test_arguments_related_interfaces():
--duration int The elapsed time of benchmark in seconds. --duration int The elapsed time of benchmark in seconds.
--num_warmup int The number of warmup step. --num_warmup int The number of warmup step.
--num_steps int The number of test step. --num_steps int The number of test step.
--sample_count int The number of data samples in dataset.
--batch_size int The number of batch size. --batch_size int The number of batch size.
--precision Precision [Precision ...] --precision Precision [Precision ...]
Model precision. E.g. float16 float32 float64 bfloat16 Model precision. E.g. float16 float32 float64 bfloat16
...@@ -177,6 +178,7 @@ def test_preprocess(): ...@@ -177,6 +178,7 @@ def test_preprocess():
--duration int The elapsed time of benchmark in seconds. --duration int The elapsed time of benchmark in seconds.
--num_warmup int The number of warmup step. --num_warmup int The number of warmup step.
--num_steps int The number of test step. --num_steps int The number of test step.
--sample_count int The number of data samples in dataset.
--batch_size int The number of batch size. --batch_size int The number of batch size.
--precision Precision [Precision ...] --precision Precision [Precision ...]
Model precision. E.g. float16 float32 float64 bfloat16 Model precision. E.g. float16 float32 float64 bfloat16
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment