vision.py 4.15 KB
Newer Older
1
import os
limm's avatar
limm committed
2
3
4
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union

5
import torch.utils.data as data
limm's avatar
limm committed
6
7

from ..utils import _log_api_usage_once
8
9
10


class VisionDataset(data.Dataset):
    """
    Base class for making datasets which are compatible with torchvision.
    It is necessary to override the ``__getitem__`` and ``__len__`` methods.

    Args:
        root (str or ``pathlib.Path``, optional): Root directory of dataset. Only used for `__repr__`.
        transforms (callable, optional): A function/transforms that takes in
            an image and a label and returns the transformed versions of both.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    Raises:
        ValueError: If ``transforms`` is passed together with ``transform`` or ``target_transform``.

    .. note::

        :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive.
    """

    # Number of spaces prepended to every body line produced by ``__repr__``.
    _repr_indent = 4

    def __init__(
        self,
        root: Optional[Union[str, Path]] = None,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        _log_api_usage_once(self)
        # Only plain strings get ``~`` expansion; other values (e.g. Path) are stored as given.
        if isinstance(root, str):
            root = os.path.expanduser(root)
        self.root = root

        has_transforms = transforms is not None
        has_separate_transform = transform is not None or target_transform is not None
        if has_transforms and has_separate_transform:
            raise ValueError("Only transforms or transform/target_transform can be passed as argument")

        # for backwards-compatibility
        self.transform = transform
        self.target_transform = target_transform

        # Wrap the separate callables so ``self.transforms`` can always be called
        # jointly as ``transforms(input, target)`` whenever any transform was given.
        if has_separate_transform:
            transforms = StandardTransform(transform, target_transform)
        self.transforms = transforms

    def __getitem__(self, index: int) -> Any:
        """
        Args:
            index (int): Index

        Returns:
            (Any): Sample and meta data, optionally transformed by the respective transforms.
        """
        raise NotImplementedError

    def __len__(self) -> int:
        """Return the number of samples; subclasses must override this."""
        raise NotImplementedError

    def __repr__(self) -> str:
        """Build a multi-line summary: a header line followed by indented detail lines."""
        head = "Dataset " + self.__class__.__name__
        body = [f"Number of datapoints: {self.__len__()}"]
        # Subclasses that skip ``super().__init__`` may never set ``root``/``transforms``,
        # so read both defensively instead of raising AttributeError.
        root = getattr(self, "root", None)
        if root is not None:
            body.append(f"Root location: {root}")
        body += self.extra_repr().splitlines()
        transforms = getattr(self, "transforms", None)
        if transforms is not None:
            # Split so every line of a multi-line repr is indented, like ``extra_repr``.
            body += repr(transforms).splitlines()
        indent = " " * self._repr_indent
        lines = [head] + [indent + line for line in body]
        return "\n".join(lines)

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        """Render ``repr(transform)`` prefixed by ``head``.

        The first line is prefixed with ``head``; continuation lines are aligned
        under it with spaces. An empty repr yields just ``[head]``.
        """
        lines = repr(transform).splitlines()
        if not lines:
            # "".splitlines() is [], which previously raised IndexError on lines[0].
            return [head]
        indent = " " * len(head)
        return [f"{head}{lines[0]}"] + [f"{indent}{line}" for line in lines[1:]]

    def extra_repr(self) -> str:
        """Return extra (possibly multi-line) text for ``__repr__``; override in subclasses."""
        return ""


class StandardTransform:
    """Apply an optional transform to an input and an optional transform to its target.

    :class:`VisionDataset` wraps its ``transform``/``target_transform`` arguments in
    this class so they can be invoked jointly as ``transforms(input, target)``.

    Args:
        transform (callable, optional): Applied to the input; skipped when ``None``.
        target_transform (callable, optional): Applied to the target; skipped when ``None``.
    """

    def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
        self.transform = transform
        self.target_transform = target_transform

    def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]:
        """Return ``(input, target)`` with each configured transform applied to its element."""
        if self.transform is not None:
            input = self.transform(input)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return input, target

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        """Render ``repr(transform)`` prefixed by ``head``.

        The first line is prefixed with ``head``; continuation lines are aligned
        under it with spaces. An empty repr yields just ``[head]``.
        """
        lines = repr(transform).splitlines()
        if not lines:
            # "".splitlines() is [], which previously raised IndexError on lines[0].
            return [head]
        indent = " " * len(head)
        return [f"{head}{lines[0]}"] + [f"{indent}{line}" for line in lines[1:]]

    def __repr__(self) -> str:
        """List the class name followed by each configured transform's repr."""
        body = [self.__class__.__name__]
        if self.transform is not None:
            body += self._format_transform_repr(self.transform, "Transform: ")
        if self.target_transform is not None:
            body += self._format_transform_repr(self.target_transform, "Target transform: ")
        return "\n".join(body)