caltech.py 8.8 KB
Newer Older
1
2
3
from PIL import Image
import os
import os.path
Philip Meier's avatar
Philip Meier committed
4
from typing import Any, Callable, List, Optional, Union, Tuple
5

6
from .vision import VisionDataset
7
from .utils import download_and_extract_archive, verify_str_arg
8
9


10
class Caltech101(VisionDataset):
    """`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.

    .. warning::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
        scipy is only imported when the ``annotation`` target type is requested.

    Args:
        root (string): Root directory of dataset where directory
            ``caltech101`` exists or will be saved to if download is set to True.
        target_type (string or list, optional): Type of target to use, ``category`` or
            ``annotation``. Can also be a list to output a tuple with all specified target types.
            ``category`` represents the target class, and ``annotation`` is a list of points
            from a hand-generated outline. Defaults to ``category``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If the ``101_ObjectCategories`` directory is missing after the
            optional download step.
    """

    def __init__(
            self,
            root: str,
            target_type: Union[List[str], str] = "category",
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
    ) -> None:
        super().__init__(os.path.join(root, 'caltech101'),
                         transform=transform,
                         target_transform=target_transform)
        os.makedirs(self.root, exist_ok=True)

        # Normalise to a list so a single target type and several are handled alike.
        if not isinstance(target_type, list):
            target_type = [target_type]
        self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation"))
                            for t in target_type]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
        self.categories.remove("BACKGROUND_Google")  # this is not a real class

        # For some reason, the category names in "101_ObjectCategories" and
        # "Annotations" do not always match. This is a manual map between the
        # two. Defaults to using same name, since most names are fine.
        name_map = {"Faces": "Faces_2",
                    "Faces_easy": "Faces_3",
                    "Motorbikes": "Motorbikes_16",
                    "airplanes": "Airplanes_Side_2"}
        self.annotation_categories = [name_map.get(c, c) for c in self.categories]

        # ``index[k]`` is the 1-based image number inside its category and
        # ``y[k]`` is the position of that category in ``self.categories``.
        self.index: List[int] = []
        self.y: List[int] = []
        for i, c in enumerate(self.categories):
            category_dir = os.path.join(self.root, "101_ObjectCategories", c)
            # Count only ``.jpg`` entries: __getitem__ can only open
            # ``image_NNNN.jpg``, so any other entry in the folder must not
            # produce an index that has no matching file.
            n = sum(1 for item in os.listdir(category_dir) if item.endswith(".jpg"))
            self.index.extend(range(1, n + 1))
            self.y.extend(n * [i])

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where the type of target specified by target_type.
        """
        category = self.y[index]
        number = self.index[index]

        img = Image.open(os.path.join(self.root,
                                      "101_ObjectCategories",
                                      self.categories[category],
                                      "image_{:04d}.jpg".format(number)))

        # scipy is only required to read the ``.mat`` annotation files, so do
        # not force the dependency on users who only ask for the category.
        if "annotation" in self.target_type:
            import scipy.io

        target: Any = []
        for t in self.target_type:
            if t == "category":
                target.append(category)
            elif t == "annotation":
                data = scipy.io.loadmat(os.path.join(self.root,
                                                     "Annotations",
                                                     self.annotation_categories[category],
                                                     "annotation_{:04d}.mat".format(number)))
                target.append(data["obj_contour"])
        # A single requested target is returned bare, several as a tuple.
        target = tuple(target) if len(target) > 1 else target[0]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        """Return True if the ``101_ObjectCategories`` directory exists under ``self.root``."""
        # can be more robust and check hash of files
        return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.index)

    def download(self) -> None:
        """Download and extract the image and annotation archives into ``self.root``.

        Does nothing except print a notice when the integrity check already passes.
        """
        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        download_and_extract_archive(
            "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
            self.root,
            filename="101_ObjectCategories.tar.gz",
            md5="b224c7392d521a49829488ab0f1120d9")
        download_and_extract_archive(
            "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
            self.root,
            filename="101_Annotations.tar",
            md5="6f83eeb1f24d99cab4eb377263132c91")

    def extra_repr(self) -> str:
        """Return the configured target types as a line for the dataset's repr."""
        return "Target type: {target_type}".format(**self.__dict__)
136
137


138
class Caltech256(VisionDataset):
    """`Caltech 256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``caltech256`` exists or will be saved to if download is set to True.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If the ``256_ObjectCategories`` directory is missing after the
            optional download step.
    """

    def __init__(
            self,
            root: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
    ) -> None:
        super().__init__(os.path.join(root, 'caltech256'),
                         transform=transform,
                         target_transform=target_transform)
        os.makedirs(self.root, exist_ok=True)

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))

        # ``index[k]`` is the 1-based image number inside its category and
        # ``y[k]`` is the position of that category in ``self.categories``.
        self.index: List[int] = []
        self.y: List[int] = []
        for i, c in enumerate(self.categories):
            category_dir = os.path.join(self.root, "256_ObjectCategories", c)
            # Count only ``.jpg`` entries: __getitem__ can only open
            # ``NNN_MMMM.jpg``, so non-image files or sub-directories in a
            # category folder must not inflate the count and yield indices
            # that have no matching file.
            n = sum(1 for item in os.listdir(category_dir) if item.endswith(".jpg"))
            self.index.extend(range(1, n + 1))
            self.y.extend(n * [i])

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        target = self.y[index]

        # File names are built from the 1-based category number and the
        # 1-based image number within that category.
        img = Image.open(os.path.join(self.root,
                                      "256_ObjectCategories",
                                      self.categories[target],
                                      "{:03d}_{:04d}.jpg".format(target + 1, self.index[index])))

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        """Return True if the ``256_ObjectCategories`` directory exists under ``self.root``."""
        # can be more robust and check hash of files
        return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.index)

    def download(self) -> None:
        """Download and extract the image archive into ``self.root``.

        Does nothing except print a notice when the integrity check already passes.
        """
        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        download_and_extract_archive(
            "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
            self.root,
            filename="256_ObjectCategories.tar",
            md5="67b4f42ca05d46448c6bb8ecd2220f6d")