vision.py 4.07 KB
Newer Older
1
import os
2
3
from typing import Any, Callable, List, Optional, Tuple

4
5
6
import torch
import torch.utils.data as data

7
8
from ..utils import _log_api_usage_once

9
10

class VisionDataset(data.Dataset):
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
    """
    Base Class For making datasets which are compatible with torchvision.
    It is necessary to override the ``__getitem__`` and ``__len__`` method.

    Args:
        root (string): Root directory of dataset.
        transforms (callable, optional): A function/transforms that takes in
            an image and a label and returns the transformed versions of both.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    .. note::

        :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive.
    """
28

29
30
    _repr_indent = 4

Philip Meier's avatar
Philip Meier committed
31
    def __init__(
32
33
34
35
36
        self,
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
Philip Meier's avatar
Philip Meier committed
37
    ) -> None:
38
        _log_api_usage_once(self)
39
40
41
42
        if isinstance(root, torch._six.string_classes):
            root = os.path.expanduser(root)
        self.root = root

43
44
45
        has_transforms = transforms is not None
        has_separate_transform = transform is not None or target_transform is not None
        if has_transforms and has_separate_transform:
46
            raise ValueError("Only transforms or transform/target_transform can be passed as argument")
47
48
49
50
51
52
53
54
55

        # for backwards-compatibility
        self.transform = transform
        self.target_transform = target_transform

        if has_separate_transform:
            transforms = StandardTransform(transform, target_transform)
        self.transforms = transforms

Philip Meier's avatar
Philip Meier committed
56
    def __getitem__(self, index: int) -> Any:
57
58
59
60
61
62
63
        """
        Args:
            index (int): Index

        Returns:
            (Any): Sample and meta data, optionally transformed by the respective transforms.
        """
64
65
        raise NotImplementedError

Philip Meier's avatar
Philip Meier committed
66
    def __len__(self) -> int:
67
68
        raise NotImplementedError

Philip Meier's avatar
Philip Meier committed
69
    def __repr__(self) -> str:
70
        head = "Dataset " + self.__class__.__name__
71
        body = [f"Number of datapoints: {self.__len__()}"]
72
        if self.root is not None:
73
            body.append(f"Root location: {self.root}")
74
        body += self.extra_repr().splitlines()
Francisco Massa's avatar
Francisco Massa committed
75
        if hasattr(self, "transforms") and self.transforms is not None:
76
            body += [repr(self.transforms)]
77
        lines = [head] + [" " * self._repr_indent + line for line in body]
78
        return "\n".join(lines)
79

Philip Meier's avatar
Philip Meier committed
80
    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
81
        lines = transform.__repr__().splitlines()
82
        return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
83

Philip Meier's avatar
Philip Meier committed
84
    def extra_repr(self) -> str:
85
        return ""
86
87


88
class StandardTransform:
    """Pair an optional input transform with an optional target transform.

    Instances expose the joint ``transforms(input, target)`` calling convention, so the
    separate ``transform`` / ``target_transform`` callables can be used wherever a single
    two-argument transform is expected.

    Args:
        transform (callable, optional): Applied to the input; skipped when ``None``.
        target_transform (callable, optional): Applied to the target; skipped when ``None``.
    """

    def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
        self.transform = transform
        self.target_transform = target_transform

    def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]:
        """Return ``(input, target)`` with each configured transform applied (input first)."""
        new_input = input if self.transform is None else self.transform(input)
        new_target = target if self.target_transform is None else self.target_transform(target)
        return new_input, new_target

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        """Prefix the first line of ``transform``'s repr with ``head`` and indent the rest to match."""
        repr_lines = transform.__repr__().splitlines()
        indent = " " * len(head)
        rendered = [head + repr_lines[0]]
        rendered.extend(indent + extra for extra in repr_lines[1:])
        return rendered

    def __repr__(self) -> str:
        """Class name followed by the representation of each configured transform."""
        pieces = [self.__class__.__name__]
        labelled = (
            (self.transform, "Transform: "),
            (self.target_transform, "Target transform: "),
        )
        for fn, label in labelled:
            if fn is not None:
                pieces.extend(self._format_transform_repr(fn, label))
        return "\n".join(pieces)