random_dataset.py 1.46 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Module to define random Dataset."""

import torch
from torch.utils.data import Dataset

from superbench.common.utils import logger


class TorchRandomDataset(Dataset):
    """Dataset that can generate the input data randomly.

    One random tensor of the requested shape is generated at construction time.
    The dataset reports ``shape[0] * world_size`` elements, and every ``world_size``
    consecutive indices map to the same generated sample.
    """
    def __init__(self, shape, world_size, dtype=torch.float):
        """Constructor.

        If ``dtype`` is not supported, an error is logged and the dataset is left empty (length 0).

        Args:
            shape (List[int]): Shape of dataset, the first dimension is the number of generated samples.
            world_size (int): Number of workers.
            dtype (torch.dtype): Type of the elements.
        """
        self._len = 0
        self._data = None
        # Assigned before the dtype check so the attribute exists on every code path,
        # including the early return for unsupported dtypes.
        self._world_size = world_size

        if dtype in [torch.float32, torch.float64]:
            self._data = torch.randn(*shape, dtype=dtype)
        elif dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
            self._data = torch.randint(0, 128, tuple(shape), dtype=dtype)
        else:
            logger.error('Unsupported precision for RandomDataset - data type: {}.'.format(dtype))
            return

        self._len = shape[0] * world_size

    def __getitem__(self, index):
        """Get the element according to index.

        Args:
            index (int): Position index.

        Return:
            Element in dataset.

        Raises:
            IndexError: If the dataset is empty or the index is out of range.
        """
        if self._data is None:
            # An empty dataset must raise IndexError so that sequence-style iteration terminates cleanly.
            raise IndexError('RandomDataset is empty - index: {}.'.format(index))
        # Integer floor division avoids float rounding for large indices and keeps
        # negative indices consistent with their positive counterparts.
        return self._data[int(index) // self._world_size]

    def __len__(self):
        """Get the count of elements.

        Return:
            The count of elements.
        """
        return self._len