caltech.py 8.52 KB
Newer Older
1
2
import os
import os.path
Philip Meier's avatar
Philip Meier committed
3
from typing import Any, Callable, List, Optional, Union, Tuple
4

5
6
from PIL import Image

7
from .utils import download_and_extract_archive, verify_str_arg
8
from .vision import VisionDataset
9
10


11
class Caltech101(VisionDataset):
    """`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.

    .. warning::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
        scipy is only imported when an ``annotation`` target is actually loaded.

    Args:
        root (string): Root directory of dataset where directory
            ``caltech101`` exists or will be saved to if download is set to True.
        target_type (string or list, optional): Type of target to use, ``category`` or
            ``annotation``. Can also be a list to output a tuple with all specified
            target types.  ``category`` represents the target class, and
            ``annotation`` is a list of points from a hand-generated outline.
            Defaults to ``category``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    def __init__(
        self,
        root: str,
        target_type: Union[List[str], str] = "category",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform)
        os.makedirs(self.root, exist_ok=True)
        if isinstance(target_type, str):
            target_type = [target_type]
        self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
        self.categories.remove("BACKGROUND_Google")  # this is not a real class

        # For some reason, the category names in "101_ObjectCategories" and
        # "Annotations" do not always match. This is a manual map between the
        # two. Defaults to using same name, since most names are fine.
        name_map = {
            "Faces": "Faces_2",
            "Faces_easy": "Faces_3",
            "Motorbikes": "Motorbikes_16",
            "airplanes": "Airplanes_Side_2",
        }
        self.annotation_categories = [name_map.get(c, c) for c in self.categories]

        # Parallel per-sample lists: ``self.index[k]`` is the 1-based file number of
        # sample k inside its category folder, and ``self.y[k]`` is the position of
        # that category in ``self.categories``.
        self.index: List[int] = []
        self.y: List[int] = []
        for (i, c) in enumerate(self.categories):
            # Every entry in the folder is counted; files are assumed to be named
            # image_0001.jpg ... image_{n:04d}.jpg (see __getitem__).
            n = len(os.listdir(os.path.join(self.root, "101_ObjectCategories", c)))
            self.index.extend(range(1, n + 1))
            self.y.extend(n * [i])

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where the type of target is specified by target_type.
            If several target types were requested, target is a tuple in that order.
        """
        img = Image.open(
            os.path.join(
                self.root,
                "101_ObjectCategories",
                self.categories[self.y[index]],
                f"image_{self.index[index]:04d}.jpg",
            )
        )

        target: Any = []
        for t in self.target_type:
            if t == "category":
                target.append(self.y[index])
            elif t == "annotation":
                # scipy is only required to read the .mat annotation files, so import it
                # here rather than forcing it on users who only want category targets.
                # Repeated imports are cheap: the module is cached after the first one.
                import scipy.io

                data = scipy.io.loadmat(
                    os.path.join(
                        self.root,
                        "Annotations",
                        self.annotation_categories[self.y[index]],
                        f"annotation_{self.index[index]:04d}.mat",
                    )
                )
                target.append(data["obj_contour"])
        target = tuple(target) if len(target) > 1 else target[0]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        """Return True if the extracted ``101_ObjectCategories`` directory exists."""
        # can be more robust and check hash of files
        # NOTE(review): the "Annotations" directory is not checked here.
        return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))

    def __len__(self) -> int:
        return len(self.index)

    def download(self) -> None:
        """Download and extract the image and annotation archives into ``self.root``.

        Nothing is downloaded if ``_check_integrity`` already succeeds.
        """
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_and_extract_archive(
            "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
            self.root,
            md5="b224c7392d521a49829488ab0f1120d9",
        )
        download_and_extract_archive(
            "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
            self.root,
            md5="6f83eeb1f24d99cab4eb377263132c91",
        )

    def extra_repr(self) -> str:
        return f"Target type: {self.target_type}"
144
145


146
class Caltech256(VisionDataset):
    """`Caltech 256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``caltech256`` exists or will be saved to if download is set to True.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform)
        os.makedirs(self.root, exist_ok=True)

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))

        # Parallel per-sample lists: ``self.index[k]`` is the 1-based image number of
        # sample k inside its category folder, and ``self.y[k]`` is the position of
        # that category in ``self.categories``.
        self.index: List[int] = []
        self.y: List[int] = []
        for (i, c) in enumerate(self.categories):
            # Only ``.jpg`` entries are counted; other files in the folder are ignored.
            n = sum(
                1
                for item in os.listdir(os.path.join(self.root, "256_ObjectCategories", c))
                if item.endswith(".jpg")
            )
            self.index.extend(range(1, n + 1))
            self.y.extend(n * [i])

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        # Files are expected to be named "{category number:03d}_{image number:04d}.jpg",
        # where the category number is the 1-based position in the sorted category list.
        img = Image.open(
            os.path.join(
                self.root,
                "256_ObjectCategories",
                self.categories[self.y[index]],
                f"{self.y[index] + 1:03d}_{self.index[index]:04d}.jpg",
            )
        )

        target = self.y[index]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        """Return True if the extracted ``256_ObjectCategories`` directory exists."""
        # can be more robust and check hash of files
        return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))

    def __len__(self) -> int:
        return len(self.index)

    def download(self) -> None:
        """Download and extract the image archive into ``self.root``.

        Nothing is downloaded if ``_check_integrity`` already succeeds.
        """
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_and_extract_archive(
            "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
            self.root,
            filename="256_ObjectCategories.tar",
            md5="67b4f42ca05d46448c6bb8ecd2220f6d",
        )