folder.py 11.5 KB
Newer Older
soumith's avatar
soumith committed
1
2
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

from PIL import Image

from .vision import VisionDataset

9

Philip Meier's avatar
Philip Meier committed
10
def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool:
    """Checks if a file is an allowed extension.

    The comparison is case-insensitive on both sides: ``"IMG.JPG"`` matches ``".jpg"``
    and ``"img.jpg"`` matches ``".JPG"``.

    Args:
        filename (string): path to a file
        extensions (string or tuple of strings): extension(s) to consider. Any other
            iterable of strings (e.g. a list) is accepted as well.

    Returns:
        bool: True if the filename ends with one of given extensions
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    # ``str.endswith`` only accepts a str or a tuple, so normalise to a tuple. Lower-casing
    # the extensions too keeps the check consistent with the lower-cased filename.
    return filename.lower().endswith(tuple(ext.lower() for ext in extensions))
soumith's avatar
soumith committed
21

22

Philip Meier's avatar
Philip Meier committed
23
def is_image_file(filename: str) -> bool:
    """Tell whether ``filename`` carries one of the recognised image extensions.

    Args:
        filename (string): path to a file

    Returns:
        bool: True when ``filename`` ends with an entry of ``IMG_EXTENSIONS``
    """
    return has_file_allowed_extension(filename, extensions=IMG_EXTENSIONS)


35
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    Every immediate sub-directory of ``directory`` is treated as one class; plain files at
    that level are ignored. See :class:`DatasetFolder` for details.

    Args:
        directory (str): Root directory of the dataset.

    Raises:
        FileNotFoundError: If ``directory`` contains no sub-directories.

    Returns:
        (Tuple[List[str], Dict[str, int]]): The class names sorted alphabetically, and a dict
        mapping each class name to its position in that sorted list.
    """
    # Use the scandir iterator as a context manager so the directory handle is released
    # deterministically, even if ``entry.is_dir()`` raises part-way through the listing.
    with os.scandir(directory) as entries:
        classes = sorted(entry.name for entry in entries if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx


Philip Meier's avatar
Philip Meier committed
48
49
def make_dataset(
    directory: str,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    See :class:`DatasetFolder` for details.

    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
    by default.

    Args:
        directory (str): Root dataset directory; a leading ``~`` is expanded.
        class_to_idx (Dict[str, int], optional): Mapping from class folder name to class index.
            Defaults to the mapping returned by ``find_classes(directory)``.
        extensions (tuple of strings, optional): Allowed file extensions.
        is_valid_file (callable, optional): Called with the full path of each candidate file
            (``os.path.join(root, fname)``); the file is kept when it returns True.
            Exactly one of ``extensions`` and ``is_valid_file`` must be given.

    Raises:
        ValueError: If ``class_to_idx`` is an empty dict, or if both or neither of
            ``extensions`` and ``is_valid_file`` are given.
        FileNotFoundError: If at least one class in ``class_to_idx`` ends up with no valid
            file (including classes whose folder does not exist).

    Returns:
        List[Tuple[str, int]]: (path_to_sample, class_index) pairs, ordered by class name,
        then by sub-directory, then by file name.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_idx' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            # Not fatal here: the class is reported below as having no valid file.
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                # Pass the full path (not the bare file name) so validators can actually open
                # or stat the file, e.g. to detect corrupt images.
                if is_valid_file(path):
                    instances.append((path, class_index))
                    available_classes.add(target_class)

    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            # Guard against a single string, which ", ".join would split into characters.
            exts = extensions if isinstance(extensions, str) else ", ".join(extensions)
            msg += f"Supported extensions are: {exts}"
        raise FileNotFoundError(msg)

    return instances
soumith's avatar
soumith committed
105

106

107
class DatasetFolder(VisionDataset):
    """A generic data loader.

    This default directory structure can be customized by overriding the
    :meth:`find_classes` method.

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string], optional): Allowed file extensions.
            ``extensions`` and ``is_valid_file`` are mutually exclusive: pass exactly one.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes the path of a file
            and returns whether the file is a valid sample (e.g. to skip corrupt files).
            ``extensions`` and ``is_valid_file`` are mutually exclusive: pass exactly one.

    Attributes:
        classes (list): List of the class names (sorted alphabetically by the default
            :meth:`find_classes`).
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each sample in the dataset
    """

    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        # Both discovery steps go through methods (not the module-level functions) so that
        # subclasses can override how classes are found and how samples are collected.
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        """Generates a list of samples of a form (path_to_sample, class).

        This can be overridden to e.g. read files from a compressed zip file instead of from the disk.

        Args:
            directory (str): root dataset directory, corresponding to ``self.root``.
            class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
            extensions (optional): A list of allowed extensions.
                Either extensions or is_valid_file should be passed. Defaults to None.
            is_valid_file (optional): A function that takes the path of a file and returns
                whether the file is a valid sample (e.g. to skip corrupt files).
                Either extensions or is_valid_file should be passed. Defaults to None.

        Raises:
            ValueError: In case ``class_to_idx`` is None or empty.
            ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
            FileNotFoundError: In case at least one class has no valid file.

        Returns:
            List[Tuple[str, int]]: samples of a form (path_to_sample, class)
        """
        if class_to_idx is None:
            # prevent potential bug since make_dataset() would use the class_to_idx logic of the
            # find_classes() function, instead of using that of the find_classes() method, which
            # is potentially overridden and thus could have a different logic.
            raise ValueError("The class_to_idx parameter cannot be None.")
        return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        """Find the class folders in a dataset structured as follows::

            directory/
            ├── class_x
            │   ├── xxx.ext
            │   ├── xxy.ext
            │   └── ...
            │       └── xxz.ext
            └── class_y
                ├── 123.ext
                ├── nsdf3.ext
                └── ...
                └── asd932_.ext

        This method can be overridden to only consider
        a subset of classes, or to adapt to a different dataset directory structure.

        Args:
            directory(str): Root directory path, corresponding to ``self.root``

        Raises:
            FileNotFoundError: If ``directory`` has no class folders.

        Returns:
            (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
        """
        return find_classes(directory)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load the sample at ``index`` and apply the optional transforms.

        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        """Return the number of collected samples."""
        return len(self.samples)


241
# Lower-case file suffixes treated as images by is_image_file() and ImageFolder.
IMG_EXTENSIONS = (
    ".jpg",
    ".jpeg",
    ".png",
    ".ppm",
    ".bmp",
    ".pgm",
    ".tif",
    ".tiff",
    ".webp",
)
242
243


Philip Meier's avatar
Philip Meier committed
244
def pil_loader(path: str) -> Image.Image:
    """Load the image at ``path`` with PIL and return it converted to RGB."""
    # Hand PIL an explicitly opened file object so the descriptor is closed when the `with`
    # block exits; this avoids a ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, "rb") as handle:
        return Image.open(handle).convert("RGB")
249
250


Philip Meier's avatar
Philip Meier committed
251
252
# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    """Load ``path`` with accimage, falling back to :func:`pil_loader` on ``OSError``."""
    import accimage

    try:
        image = accimage.Image(path)
    except OSError:
        # Potentially a decoding problem accimage cannot handle; let PIL try instead.
        return pil_loader(path)
    return image


Philip Meier's avatar
Philip Meier committed
262
def default_loader(path: str) -> Any:
    """Load ``path`` with the loader matching torchvision's configured image backend.

    Uses :func:`accimage_loader` when the backend is ``"accimage"`` and :func:`pil_loader`
    otherwise.
    """
    from torchvision import get_image_backend

    loader = accimage_loader if get_image_backend() == "accimage" else pil_loader
    return loader(path)


271
class ImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way by default: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png

    This class inherits from :class:`~torchvision.datasets.DatasetFolder` so
    the same methods can be overridden to customize the dataset.

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
            Defaults to :func:`default_loader`.
        is_valid_file (callable, optional): A function that takes the path of an image file
            and returns whether the file is a valid sample (e.g. to skip corrupt files).
            When given, it replaces the default ``IMG_EXTENSIONS`` extension check.

    Attributes:
        classes (list): List of the class names (sorted alphabetically by the default
            :meth:`find_classes`).
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples; the same list object as ``samples``.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(
            root,
            loader,
            # Fall back to the image-extension filter only when no custom validator is given:
            # the default make_dataset raises ValueError if both are passed.
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        # ``imgs`` aliases ``samples`` (same list object, not a copy).
        self.imgs = self.samples