Unverified Commit ebea2d50 authored by guoshzhao's avatar guoshzhao Committed by GitHub
Browse files

Benchmarks: Add Feature - Add random dataset for PyTorch. (#17)



* add random dataset.

* install pytorch-cpu for test docker.

* fix typo

* add more test cases.

* address comments.
Co-authored-by: Guoshuai Zhao <guzhao@microsoft.com>
parent 52848d2f
...@@ -145,6 +145,11 @@ def run(self): ...@@ -145,6 +145,11 @@ def run(self):
'pytest>=6.2.2', 'pytest>=6.2.2',
'pytest-cov>=2.11.1', 'pytest-cov>=2.11.1',
], ],
'torch': [
'torch==1.7.0',
'torchvision==0.8.0',
'transformers==4.3.3',
],
}, },
package_data={}, package_data={},
entry_points={ entry_points={
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Module to define random Dataset."""
import torch
from torch.utils.data import Dataset
from superbench.common.utils import logger
class TorchRandomDataset(Dataset):
    """Dataset that can generate the input data randomly.

    One tensor of the requested shape is generated at construction time.
    The reported length is ``shape[0] * world_size``, and every run of
    ``world_size`` consecutive indices maps to the same row of that tensor.
    """

    def __init__(self, shape, world_size, dtype=torch.float):
        """Constructor.

        If ``dtype`` is not supported, an error is logged and the dataset is
        left empty (length 0, no data) instead of raising.

        Args:
            shape (List[int]): Shape of dataset.
            world_size (int): Number of workers.
            dtype (torch.dtype): Type of the elements.
        """
        self._len = 0
        self._data = None
        # Set before the dtype check so the attribute exists on every
        # instance, including the empty one left by the unsupported path.
        self._world_size = world_size

        if dtype in [torch.float32, torch.float64]:
            self._data = torch.randn(*shape, dtype=dtype)
        elif dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
            # Upper bound 128 keeps every value representable in int8.
            self._data = torch.randint(0, 128, tuple(shape), dtype=dtype)
        else:
            logger.error('Unsupported precision for RandomDataset - data type: {}.'.format(dtype))
            return

        self._len = shape[0] * world_size

    def __getitem__(self, index):
        """Get the element according to index.

        Args:
            index (int): Position index.

        Return:
            Element in dataset.

        Raises:
            IndexError: If the dataset is empty or the index is out of range.
        """
        if self._data is None:
            raise IndexError('index {} is out of range for an empty dataset'.format(index))
        # Floor division keeps the mapping in exact integer arithmetic and
        # makes negative indices count from the end (true division followed
        # by int() truncated toward zero, so -1 resolved to row 0).
        return self._data[int(index) // self._world_size]

    def __len__(self):
        """Get the count of elements.

        Return:
            The count of elements.
        """
        return self._len
...@@ -20,7 +20,7 @@ COPY . /superbench ...@@ -20,7 +20,7 @@ COPY . /superbench
# Upgrade pip and install dependencies # Upgrade pip and install dependencies
RUN python3 -m pip install --upgrade pip setuptools && \ RUN python3 -m pip install --upgrade pip setuptools && \
python3 -m pip install .[test] python3 -m pip install .[test,torch]
# Lint code # Lint code
RUN python3 setup.py lint RUN python3 setup.py lint
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Tests for random_dataset module."""
import torch
from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDataset
def test_torch_random_dataset():
    """Verify length, row width and dtype of TorchRandomDataset.

    Covers every supported dtype for one and two workers, then the
    unsupported-dtype case, which must yield an empty dataset.
    """
    shape = [32, 64]
    supported_types = [torch.float32, torch.float64, torch.int8, torch.int16, torch.int32, torch.int64]

    # (world_size, expected dataset length) pairs, checked in this order.
    cases = ((1, 32), (2, 64))
    for world_size, expected_len in cases:
        for dtype in supported_types:
            dataset = TorchRandomDataset(shape, world_size, dtype=dtype)
            assert (len(dataset) == expected_len)
            assert (len(dataset[0]) == 64)
            assert (dataset._data.dtype == dtype)

    # Case for unsupported data type.
    dataset = TorchRandomDataset(shape, 2, dtype=torch.float16)
    assert (len(dataset) == 0)
    assert (dataset._data is None)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment