caltech.py 8.58 KB
Newer Older
1
2
import os
import os.path
Philip Meier's avatar
Philip Meier committed
3
from typing import Any, Callable, List, Optional, Union, Tuple
4

5
6
from PIL import Image

7
from .utils import download_and_extract_archive, verify_str_arg
8
from .vision import VisionDataset
9
10


11
class Caltech101(VisionDataset):
    """`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.

    .. warning::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.

    Args:
        root (string): Root directory of dataset where directory
            ``caltech101`` exists or will be saved to if download is set to True.
        target_type (string or list, optional): Type of target to use, ``category`` or
            ``annotation``. Can also be a list (or tuple) to output a tuple with all
            specified target types. ``category`` represents the target class, and
            ``annotation`` is a list of points from a hand-generated outline.
            Defaults to ``category``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    def __init__(
        self,
        root: str,
        target_type: Union[List[str], str] = "category",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform)
        os.makedirs(self.root, exist_ok=True)

        # Accept both lists and tuples of target types; anything else is treated
        # as a single value and left to ``verify_str_arg`` to validate.
        if not isinstance(target_type, (list, tuple)):
            target_type = [target_type]
        self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
        self.categories.remove("BACKGROUND_Google")  # this is not a real class

        # For some reason, the category names in "101_ObjectCategories" and
        # "Annotations" do not always match. This is a manual map between the
        # two. Defaults to using same name, since most names are fine.
        name_map = {
            "Faces": "Faces_2",
            "Faces_easy": "Faces_3",
            "Motorbikes": "Motorbikes_16",
            "airplanes": "Airplanes_Side_2",
        }
        self.annotation_categories = [name_map.get(c, c) for c in self.categories]

        # ``self.index[k]`` is the 1-based image number of sample ``k`` inside its
        # category directory and ``self.y[k]`` is the index of that category.
        self.index: List[int] = []
        self.y = []
        for i, c in enumerate(self.categories):
            category_dir = os.path.join(self.root, "101_ObjectCategories", c)
            # Count only the ``.jpg`` files: __getitem__ builds ``image_NNNN.jpg``
            # paths from this count, so any other directory entry must not
            # contribute to it.
            n = sum(1 for entry in os.listdir(category_dir) if entry.endswith(".jpg"))
            self.index.extend(range(1, n + 1))
            self.y.extend(n * [i])

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where the type of target is specified by target_type.
        """
        category = self.y[index]
        image_number = self.index[index]

        img = Image.open(
            os.path.join(
                self.root,
                "101_ObjectCategories",
                self.categories[category],
                f"image_{image_number:04d}.jpg",
            )
        )

        target: Any = []
        for t in self.target_type:
            if t == "category":
                target.append(category)
            elif t == "annotation":
                # Imported lazily so that scipy is only required when
                # annotations are actually requested.
                import scipy.io

                data = scipy.io.loadmat(
                    os.path.join(
                        self.root,
                        "Annotations",
                        self.annotation_categories[category],
                        f"annotation_{image_number:04d}.mat",
                    )
                )
                target.append(data["obj_contour"])
        # A single requested target type is returned bare, several as a tuple.
        target = tuple(target) if len(target) > 1 else target[0]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        """Return True if the ``101_ObjectCategories`` directory exists under the dataset root."""
        # can be more robust and check hash of files
        return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))

    def __len__(self) -> int:
        """Return the number of samples indexed at construction time."""
        return len(self.index)

    def download(self) -> None:
        """Download and extract the image and annotation archives unless the images are already present."""
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_and_extract_archive(
            "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
            self.root,
            filename="101_ObjectCategories.tar.gz",
            md5="b224c7392d521a49829488ab0f1120d9",
        )
        download_and_extract_archive(
            "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
            self.root,
            filename="101_Annotations.tar",
            md5="6f83eeb1f24d99cab4eb377263132c91",
        )

    def extra_repr(self) -> str:
        """Return a line describing the configured target types."""
        return f"Target type: {self.target_type}"
148
149


150
class Caltech256(VisionDataset):
    """`Caltech 256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``caltech256`` exists or will be saved to if download is set to True.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform)
        os.makedirs(self.root, exist_ok=True)

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))

        # ``self.index[k]`` is the 1-based image number of sample ``k`` inside its
        # category directory and ``self.y[k]`` is the index of that category.
        self.index: List[int] = []
        self.y = []
        for i, c in enumerate(self.categories):
            category_dir = os.path.join(self.root, "256_ObjectCategories", c)
            # Count only the ``.jpg`` files: __getitem__ builds ``CCC_NNNN.jpg``
            # paths from this count, so a stray non-image entry in a category
            # directory must not produce an index pointing at a missing file.
            n = sum(1 for entry in os.listdir(category_dir) if entry.endswith(".jpg"))
            self.index.extend(range(1, n + 1))
            self.y.extend(n * [i])

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        target = self.y[index]

        # File names are ``<1-based category number>_<1-based image number>.jpg``.
        img = Image.open(
            os.path.join(
                self.root,
                "256_ObjectCategories",
                self.categories[target],
                f"{target + 1:03d}_{self.index[index]:04d}.jpg",
            )
        )

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        """Return True if the ``256_ObjectCategories`` directory exists under the dataset root."""
        # can be more robust and check hash of files
        return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))

    def __len__(self) -> int:
        """Return the number of samples indexed at construction time."""
        return len(self.index)

    def download(self) -> None:
        """Download and extract the image archive unless the images are already present."""
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_and_extract_archive(
            "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
            self.root,
            filename="256_ObjectCategories.tar",
            md5="67b4f42ca05d46448c6bb8ecd2220f6d",
        )