random_dataset.py 1.71 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Module to define random Dataset."""

import torch
from torch.utils.data import Dataset

from superbench.common.utils import logger


class TorchRandomDataset(Dataset):
    """Dataset that can generate the input data randomly.

    One random tensor of ``shape`` is generated once, at construction time.
    The dataset reports ``shape[0] * world_size`` elements. Every run of
    ``world_size`` consecutive indices maps onto the same row (first-dimension
    slice) of that tensor.
    """
    def __init__(self, shape, world_size, dtype=torch.float):
        """Constructor.

        If ``dtype`` is unsupported or tensor generation fails, an error is
        logged and the dataset is left empty (length 0, no data) instead of
        raising.

        Args:
            shape (List[int]): Shape of dataset.
            world_size (int): Number of workers.
            dtype (torch.dtype): Type of the elements. Supported: float32,
                float64, int8, int16, int32, int64.
        """
        self._len = 0
        self._data = None
        # Assigned before any early return so the attribute always exists.
        self._world_size = world_size

        try:
            if dtype in [torch.float32, torch.float64]:
                self._data = torch.randn(*shape, dtype=dtype)
            elif dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
                # Values drawn from [0, 128) so they fit even the narrowest type, int8.
                self._data = torch.randint(0, 128, tuple(shape), dtype=dtype)
            else:
                logger.error('Unsupported precision for RandomDataset - data type: {}.'.format(dtype))
                return
        except Exception as e:
            # Catch Exception rather than BaseException so Ctrl-C / SystemExit still propagate.
            logger.error(
                'Generate random dataset failed - data type: {}, shape: {}, message: {}.'.format(dtype, shape, str(e))
            )
            return

        self._len = shape[0] * world_size

    def __getitem__(self, index):
        """Get the element according to index.

        Indices ``k * world_size`` through ``k * world_size + world_size - 1``
        all return row ``k`` of the generated tensor.

        Args:
            index (int): Position index.

        Return:
            Element in dataset.
        """
        # Exact integer floor division. int(index / world_size) goes through
        # float and picks the wrong row once index exceeds 2**53.
        return self._data[int(index) // self._world_size]

    def __len__(self):
        """Get the count of elements.

        Return:
            The count of elements (0 if generation failed).
        """
        return self._len