Unverified Commit c00dc670 authored by guoshzhao's avatar guoshzhao Committed by GitHub
Browse files

Benchmarks: Add Feature - Add sample_count argument for ModelBenchmark. (#22)



* add sample_count argument.

* handle more condidatins.

* address comments.
Co-authored-by: default avatarGuoshuai Zhao <guzhao@microsoft.com>
parent 31b6f085
......@@ -3,6 +3,7 @@
"""Module of the model-benchmark base class."""
import math
from abc import abstractmethod
from superbench.common.utils import logger
......@@ -76,6 +77,13 @@ def add_parser_arguments(self):
required=False,
help='The number of test step.',
)
self._parser.add_argument(
'--sample_count',
type=int,
default=128,
required=False,
help='The number of data samples in dataset.',
)
self._parser.add_argument(
'--batch_size',
type=int,
......@@ -170,6 +178,9 @@ def _preprocess(self):
self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE)
return False
# Set sample_count aligned with batch_size.
self._args.sample_count = math.ceil(self._args.sample_count / self._args.batch_size) * self._args.batch_size
if not self._generate_dataset():
self._result.set_return_code(ReturnCode.DATASET_GENERATION_FAILURE)
return False
......
......@@ -22,12 +22,18 @@ def __init__(self, shape, world_size, dtype=torch.float):
self._len = 0
self._data = None
if dtype in [torch.float32, torch.float64]:
self._data = torch.randn(*shape, dtype=dtype)
elif dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
self._data = torch.randint(0, 128, tuple(shape), dtype=dtype)
else:
logger.error('Unsupported precision for RandomDataset - data type: {}.'.format(dtype))
try:
if dtype in [torch.float32, torch.float64]:
self._data = torch.randn(*shape, dtype=dtype)
elif dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
self._data = torch.randint(0, 128, tuple(shape), dtype=dtype)
else:
logger.error('Unsupported precision for RandomDataset - data type: {}.'.format(dtype))
return
except BaseException as e:
logger.error(
'Generate random dataset failed - data type: {}, shape: {}, message: {}.'.format(dtype, shape, str(e))
)
return
self._len = shape[0] * world_size
......
......@@ -146,6 +146,7 @@ def test_arguments_related_interfaces():
--duration int The elapsed time of benchmark in seconds.
--num_warmup int The number of warmup step.
--num_steps int The number of test step.
--sample_count int The number of data samples in dataset.
--batch_size int The number of batch size.
--precision Precision [Precision ...]
Model precision. E.g. float16 float32 float64 bfloat16
......@@ -177,6 +178,7 @@ def test_preprocess():
--duration int The elapsed time of benchmark in seconds.
--num_warmup int The number of warmup step.
--num_steps int The number of test step.
--sample_count int The number of data samples in dataset.
--batch_size int The number of batch size.
--precision Precision [Precision ...]
Model precision. E.g. float16 float32 float64 bfloat16
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment