vision.py 3.25 KB
Newer Older
1
2
3
import os
import torch
import torch.utils.data as data
Philip Meier's avatar
Philip Meier committed
4
from typing import Any, Callable, List, Optional, Tuple
5
6
7
8
9


class VisionDataset(data.Dataset):
    """Base class for datasets that are anchored at a root location.

    Subclasses must implement ``__getitem__`` and ``__len__`` (both raise
    ``NotImplementedError`` here).  The constructor accepts either a joint
    ``transforms`` callable, or the separate ``transform`` /
    ``target_transform`` callables, but not both at once.

    Args:
        root: Root location of the dataset.  ``str`` / ``bytes`` values have a
            leading ``~`` expanded via ``os.path.expanduser``; any other value
            (including ``None``) is stored unchanged.
        transforms: Joint callable stored as-is on ``self.transforms``;
            presumably expected to accept ``(input, target)`` like
            :class:`StandardTransform` does -- confirm against subclasses.
        transform: Callable for the input only.
        target_transform: Callable for the target only.

    Raises:
        ValueError: If ``transforms`` is given together with ``transform``
            or ``target_transform``.
    """

    _repr_indent = 4  # spaces prefixed to every body line of ``__repr__``

    def __init__(
            self,
            root: str,
            transforms: Optional[Callable] = None,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
    ) -> None:
        # ``torch._six`` was a private Python 2/3 compatibility shim removed in
        # PyTorch 2.0; on Python 3 its ``string_classes`` was ``(str, bytes)``.
        # Spelling the tuple out keeps the same behavior without the import.
        if isinstance(root, (str, bytes)):
            root = os.path.expanduser(root)
        self.root = root

        has_transforms = transforms is not None
        has_separate_transform = transform is not None or target_transform is not None
        if has_transforms and has_separate_transform:
            raise ValueError("Only transforms or transform/target_transform can "
                             "be passed as argument")

        # for backwards-compatibility: keep the individual callables reachable
        self.transform = transform
        self.target_transform = target_transform

        if has_separate_transform:
            # Wrap the separate callables so ``self.transforms`` is always a
            # single callable over ``(input, target)`` (or None).
            transforms = StandardTransform(transform, target_transform)
        self.transforms = transforms

    def __getitem__(self, index: int) -> Any:
        """Return the sample at ``index``; must be implemented by subclasses."""
        raise NotImplementedError

    def __len__(self) -> int:
        """Return the number of samples; must be implemented by subclasses."""
        raise NotImplementedError

    def __repr__(self) -> str:
        """Multi-line summary: class name, size, root, extras, transforms."""
        head = "Dataset " + self.__class__.__name__
        body = ["Number of datapoints: {}".format(self.__len__())]
        if self.root is not None:
            body.append("Root location: {}".format(self.root))
        body += self.extra_repr().splitlines()
        if hasattr(self, "transforms") and self.transforms is not None:
            # Split so every line of a multi-line repr (e.g. StandardTransform)
            # receives the indent, not just the first one.
            body += repr(self.transforms).splitlines()
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return '\n'.join(lines)

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        """Render ``repr(transform)`` prefixed by ``head``.

        The first line is prefixed with ``head``; following lines are
        indented by ``len(head)`` spaces so they align under it.  An empty
        repr yields just ``[head]``.
        """
        lines = repr(transform).splitlines()
        if not lines:
            # Guard: indexing lines[0] on an empty repr would raise IndexError.
            return [head]
        indent = " " * len(head)
        return [head + lines[0]] + [indent + line for line in lines[1:]]

    def extra_repr(self) -> str:
        """Hook for extra ``__repr__`` lines (newline-separated); empty here."""
        return ""
59
60
61


class StandardTransform:
    """Combine separate input/target callables into one ``(input, target)`` callable.

    Args:
        transform: Callable applied to the input, or ``None`` to pass it through.
        target_transform: Callable applied to the target, or ``None`` to pass
            it through.
    """

    def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
        self.transform = transform
        self.target_transform = target_transform

    def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]:
        """Return ``(input, target)`` with each configured callable applied."""
        if self.transform is not None:
            input = self.transform(input)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return input, target

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        """Render ``repr(transform)`` with ``head`` before the first line.

        Continuation lines are indented by ``len(head)`` spaces so they line
        up under the first one.  An empty repr yields just ``[head]``.
        """
        lines = repr(transform).splitlines()
        if not lines:
            # Guard: indexing lines[0] on an empty repr would raise IndexError.
            return [head]
        indent = " " * len(head)
        return [head + lines[0]] + [indent + line for line in lines[1:]]

    def __repr__(self) -> str:
        """Class name followed by one labelled section per configured callable."""
        body = [self.__class__.__name__]
        if self.transform is not None:
            body += self._format_transform_repr(self.transform,
                                                "Transform: ")
        if self.target_transform is not None:
            body += self._format_transform_repr(self.target_transform,
                                                "Target transform: ")

        return '\n'.join(body)