caltech.py 8.63 KB
Newer Older
1
2
import os
import os.path
Philip Meier's avatar
Philip Meier committed
3
from typing import Any, Callable, List, Optional, Union, Tuple
4

5
6
from PIL import Image

7
from .utils import download_and_extract_archive, verify_str_arg
8
from .vision import VisionDataset
9
10


11
class Caltech101(VisionDataset):
    """`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.

    .. warning::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
        scipy is imported only when an ``annotation`` target is actually loaded, so
        ``category``-only use does not require it.

    Args:
        root (string): Root directory of dataset where directory
            ``caltech101`` exists or will be saved to if download is set to True.
        target_type (string or list, optional): Type of target to use, ``category`` or
            ``annotation``. Can also be a list to output a tuple with all specified
            target types.  ``category`` represents the target class, and
            ``annotation`` is a list of points from a hand-generated outline.
            Defaults to ``category``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If the ``101_ObjectCategories`` directory cannot be found under
            the dataset root (after the optional download).
    """

    def __init__(
        self,
        root: str,
        target_type: Union[List[str], str] = "category",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform)
        os.makedirs(self.root, exist_ok=True)

        # Normalise to a list so that a single target type and several are handled alike.
        if isinstance(target_type, str):
            target_type = [target_type]
        self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")

        image_root = os.path.join(self.root, "101_ObjectCategories")
        # "BACKGROUND_Google" is not a real class. It is filtered out rather than
        # removed so that a copy without that folder does not fail with ValueError.
        self.categories = sorted(c for c in os.listdir(image_root) if c != "BACKGROUND_Google")

        # For some reason, the category names in "101_ObjectCategories" and
        # "Annotations" do not always match. This is a manual map between the
        # two. Defaults to using same name, since most names are fine.
        name_map = {
            "Faces": "Faces_2",
            "Faces_easy": "Faces_3",
            "Motorbikes": "Motorbikes_16",
            "airplanes": "Airplanes_Side_2",
        }
        self.annotation_categories = [name_map.get(c, c) for c in self.categories]

        # self.index[k] is the 1-based file number of sample k inside its category
        # folder; self.y[k] is the position of that category in self.categories.
        self.index: List[int] = []
        self.y: List[int] = []
        for i, c in enumerate(self.categories):
            # Count only ".jpg" entries (as Caltech256 does) so that stray files in a
            # category folder do not yield indices for images that do not exist.
            n = sum(1 for item in os.listdir(os.path.join(image_root, c)) if item.endswith(".jpg"))
            self.index.extend(range(1, n + 1))
            self.y.extend(n * [i])

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where the type of target specified by target_type.
            With a single target type the target is returned directly; with several
            it is a tuple in the order given by ``target_type``.
        """
        label = self.y[index]
        file_number = self.index[index]

        img = Image.open(
            os.path.join(
                self.root,
                "101_ObjectCategories",
                self.categories[label],
                "image_{:04d}.jpg".format(file_number),
            )
        )

        target: Any = []
        for t in self.target_type:
            if t == "category":
                target.append(label)
            elif t == "annotation":
                # Imported here so scipy is only required when annotations are read.
                import scipy.io

                data = scipy.io.loadmat(
                    os.path.join(
                        self.root,
                        "Annotations",
                        self.annotation_categories[label],
                        "annotation_{:04d}.mat".format(file_number),
                    )
                )
                target.append(data["obj_contour"])
        target = tuple(target) if len(target) > 1 else target[0]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        """Return True if the ``101_ObjectCategories`` directory exists under the root."""
        # can be more robust and check hash of files
        return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.index)

    def download(self) -> None:
        """Download and extract the image and annotation archives into the root.

        Does nothing but print a message when the integrity check already passes.
        """
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_and_extract_archive(
            "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
            self.root,
            md5="b224c7392d521a49829488ab0f1120d9",
        )
        download_and_extract_archive(
            "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
            self.root,
            md5="6f83eeb1f24d99cab4eb377263132c91",
        )

    def extra_repr(self) -> str:
        """Return the configured target types as a string for the dataset repr."""
        return "Target type: {target_type}".format(**self.__dict__)
146
147


148
class Caltech256(VisionDataset):
    """`Caltech 256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ Dataset.

    Each sample is an image paired with the index of its category in the sorted
    listing of the ``256_ObjectCategories`` directory.

    Args:
        root (string): Root directory of dataset where directory
            ``caltech256`` exists or will be saved to if download is set to True.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        dataset_dir = os.path.join(root, "caltech256")
        super().__init__(dataset_dir, transform=transform, target_transform=target_transform)
        os.makedirs(self.root, exist_ok=True)

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")

        images_dir = os.path.join(self.root, "256_ObjectCategories")
        self.categories = sorted(os.listdir(images_dir))

        # index[k] holds the 1-based image number of sample k within its category,
        # y[k] the position of that category in self.categories.
        self.index: List[int] = []
        self.y = []
        for label, category in enumerate(self.categories):
            # Only ".jpg" entries count as images; anything else in the folder is ignored.
            jpg_count = sum(
                1 for entry in os.listdir(os.path.join(images_dir, category)) if entry.endswith(".jpg")
            )
            self.index += range(1, jpg_count + 1)
            self.y += [label] * jpg_count

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        label = self.y[index]
        # File names are built from the 1-based category number and the image number.
        file_name = "{:03d}_{:04d}.jpg".format(label + 1, self.index[index])
        image_path = os.path.join(self.root, "256_ObjectCategories", self.categories[label], file_name)

        img = Image.open(image_path)
        target = label

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        """Return True if the ``256_ObjectCategories`` directory exists under the root."""
        # can be more robust and check hash of files
        images_dir = os.path.join(self.root, "256_ObjectCategories")
        return os.path.exists(images_dir)

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.index)

    def download(self) -> None:
        """Download and extract the image archive into the root unless already present."""
        if not self._check_integrity():
            download_and_extract_archive(
                "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
                self.root,
                filename="256_ObjectCategories.tar",
                md5="67b4f42ca05d46448c6bb8ecd2220f6d",
            )
            return

        print("Files already downloaded and verified")