folder.py 12.6 KB
Newer Older
limm's avatar
limm committed
1
2
3
4
import os
import os.path
from pathlib import Path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
soumith's avatar
soumith committed
5
6

from PIL import Image
7

limm's avatar
limm committed
8
from .vision import VisionDataset
soumith's avatar
soumith committed
9

10

limm's avatar
limm committed
11
def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool:
    """Checks if a file is an allowed extension.

    The comparison is case-insensitive on both sides: ``"photo.JPG"`` matches ``".jpg"`` and
    ``"photo.jpg"`` matches ``".JPG"``.

    Args:
        filename (string): path to a file
        extensions (string or tuple of strings): extension(s) to consider. Any iterable of
            strings (e.g. a list) is accepted and converted to a tuple.

    Returns:
        bool: True if the filename ends with one of given extensions
    """
    # str.endswith only accepts a str or a tuple of str, so other iterables are normalised
    # to a tuple. Lowercasing the extensions as well as the filename keeps the check
    # symmetric; previously an upper-case extension could never match.
    if isinstance(extensions, str):
        extensions = extensions.lower()
    else:
        extensions = tuple(ext.lower() for ext in extensions)
    return filename.lower().endswith(extensions)


def is_image_file(filename: str) -> bool:
    """Tell whether ``filename`` carries one of the known image extensions.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with an entry of ``IMG_EXTENSIONS``
    """
    allowed = IMG_EXTENSIONS
    return has_file_allowed_extension(filename, extensions=allowed)


def find_classes(directory: Union[str, Path]) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    Every immediate sub-directory of ``directory`` is treated as one class. Class names are
    sorted and mapped to consecutive indices starting at 0.

    See :class:`DatasetFolder` for details.

    Args:
        directory (str or ``pathlib.Path``): Root directory path. A leading ``~`` is expanded,
            the same way :func:`make_dataset` expands it.

    Raises:
        FileNotFoundError: If ``directory`` contains no class folders.

    Returns:
        (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
    """
    directory = os.path.expanduser(directory)
    # Use the scandir iterator as a context manager so the directory handle is released
    # even if ``is_dir()`` raises part-way through the scan.
    with os.scandir(directory) as entries:
        classes = sorted(entry.name for entry in entries if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx


def make_dataset(
    directory: Union[str, Path],
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Union[str, Tuple[str, ...]]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
    allow_empty: bool = False,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    See :class:`DatasetFolder` for details.

    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
    by default.

    Args:
        directory (str or ``pathlib.Path``): root dataset directory; a leading ``~`` is expanded.
        class_to_idx (Dict[str, int], optional): mapping from class (sub-folder) name to class index.
            When None it is computed with :func:`find_classes`.
        extensions (str or tuple[str], optional): allowed file extension(s).
            Exactly one of ``extensions`` and ``is_valid_file`` must be passed.
        is_valid_file (callable, optional): predicate receiving the full path of a file and returning
            whether it should be kept. Exactly one of ``extensions`` and ``is_valid_file`` must be passed.
        allow_empty (bool, optional): If True, classes without any valid file are accepted.
            Otherwise (default) they raise an error.

    Raises:
        ValueError: If ``class_to_idx`` is empty, or if ``extensions`` and ``is_valid_file`` are
            both None or both not None.
        FileNotFoundError: If some class has no valid file and ``allow_empty`` is False.

    Returns:
        List[Tuple[str, int]]: samples of a form (path_to_sample, class_index), grouped by class
        name in sorted order and, within a class, in sorted directory-walk order.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_idx' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            # A missing class folder is not an error by itself; it is reported below as an
            # empty class unless ``allow_empty`` is set.
            continue
        # followlinks=True traverses symlinked sub-folders; sorting both the walk and the file
        # names makes the sample order deterministic across file systems.
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    instances.append((path, class_index))
                    available_classes.add(target_class)

    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes and not allow_empty:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances


class DatasetFolder(VisionDataset):
    """A generic data loader.

    Samples are discovered on disk from one sub-folder per class; the expected layout is shown
    in :meth:`find_classes`. This default directory structure can be customized by overriding the
    :meth:`find_classes` method, and the way samples are collected by overriding
    :meth:`make_dataset`.

    Args:
        root (str or ``pathlib.Path``): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string], optional): Allowed file extensions.
            Exactly one of ``extensions`` and ``is_valid_file`` must be passed.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g., ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes the path of a file
            and checks whether it is a valid file (e.g. to filter out corrupt files).
            Exactly one of ``extensions`` and ``is_valid_file`` must be passed.
        allow_empty (bool, optional): If True, empty folders are considered to be valid classes.
            An error is raised on empty folders if False (default).

    Attributes:
        classes (list): List of the class names (sorted alphabetically by the default
            :meth:`find_classes`).
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each sample in the dataset
        loader (callable): The ``loader`` passed at construction.
        extensions (tuple[string] or None): The ``extensions`` passed at construction.
    """

    def __init__(
        self,
        root: Union[str, Path],
        loader: Callable[[str], Any],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        allow_empty: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        # Both steps go through ``self`` so that subclasses overriding find_classes /
        # make_dataset change how the dataset is indexed.
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(
            self.root,
            class_to_idx=class_to_idx,
            extensions=extensions,
            is_valid_file=is_valid_file,
            allow_empty=allow_empty,
        )

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: Union[str, Path],
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        allow_empty: bool = False,
    ) -> List[Tuple[str, int]]:
        """Generates a list of samples of a form (path_to_sample, class).

        This can be overridden to e.g. read files from a compressed zip file instead of from the disk.

        Args:
            directory (str or ``pathlib.Path``): root dataset directory, corresponding to ``self.root``.
            class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
            extensions (optional): A list of allowed extensions.
                Exactly one of ``extensions`` and ``is_valid_file`` must be passed. Defaults to None.
            is_valid_file (optional): A function that takes the path of a file
                and checks whether it is a valid file (e.g. to filter out corrupt files).
                Exactly one of ``extensions`` and ``is_valid_file`` must be passed. Defaults to None.
            allow_empty (bool, optional): If True, empty folders are considered to be valid classes.
                An error is raised on empty folders if False (default).

        Raises:
            ValueError: In case ``class_to_idx`` is None or empty.
            ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
            FileNotFoundError: In case some class has no valid file and ``allow_empty`` is False.

        Returns:
            List[Tuple[str, int]]: samples of a form (path_to_sample, class)
        """
        if class_to_idx is None:
            # prevent potential bug since make_dataset() would use the class_to_idx logic of the
            # find_classes() function, instead of using that of the find_classes() method, which
            # is potentially overridden and thus could have a different logic.
            raise ValueError("The class_to_idx parameter cannot be None.")
        # Inside this staticmethod body the bare name resolves to the module-level function,
        # not to this method, because class scope is not searched from nested functions.
        return make_dataset(
            directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file, allow_empty=allow_empty
        )

    def find_classes(self, directory: Union[str, Path]) -> Tuple[List[str], Dict[str, int]]:
        """Find the class folders in a dataset structured as follows::

            directory/
            ├── class_x
            │   ├── xxx.ext
            │   ├── xxy.ext
            │   └── ...
            │       └── xxz.ext
            └── class_y
                ├── 123.ext
                ├── nsdf3.ext
                └── ...
                └── asd932_.ext

        This method can be overridden to only consider
        a subset of classes, or to adapt to a different dataset directory structure.

        Args:
            directory (str or ``pathlib.Path``): Root directory path, corresponding to ``self.root``

        Raises:
            FileNotFoundError: If ``directory`` has no class folders.

        Returns:
            (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
        """
        return find_classes(directory)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load and transform one sample.

        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        """Return the number of collected samples."""
        return len(self.samples)


# Lower-case file suffixes accepted as images by ``is_image_file`` and, by default, ``ImageFolder``.
IMG_EXTENSIONS = (
    ".jpg",
    ".jpeg",
    ".png",
    ".ppm",
    ".bmp",
    ".pgm",
    ".tif",
    ".tiff",
    ".webp",
)


def pil_loader(path: str) -> Image.Image:
    """Open the image at ``path`` with PIL and return it converted to RGB."""
    # Going through an explicit file handle (rather than passing the path to PIL) avoids a
    # ResourceWarning: https://github.com/python-pillow/Pillow/issues/835
    with open(path, "rb") as handle:
        return Image.open(handle).convert("RGB")


# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    """Load ``path`` with accimage, falling back to :func:`pil_loader` on an ``OSError``."""
    import accimage

    try:
        picture = accimage.Image(path)
    except OSError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)
    return picture


def default_loader(path: str) -> Any:
    """Load ``path`` with the loader matching torchvision's current image backend.

    Uses :func:`accimage_loader` when the backend is ``"accimage"``, otherwise :func:`pil_loader`.
    """
    from torchvision import get_image_backend

    chosen = accimage_loader if get_image_backend() == "accimage" else pil_loader
    return chosen(path)


class ImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way by default: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png

    This class inherits from :class:`~torchvision.datasets.DatasetFolder` so
    the same methods can be overridden to customize the dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory path.
        transform (callable, optional): A function/transform that takes in the loaded image
            (a PIL image with the default PIL backend) and returns a transformed version.
            E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes the path of an image file
            and checks whether it is a valid file (e.g. to filter out corrupt files).
            When given, it replaces the default ``IMG_EXTENSIONS`` filter.
        allow_empty (bool, optional): If True, empty folders are considered to be valid classes.
            An error is raised on empty folders if False (default).

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples; the same list object as ``samples``.
    """

    def __init__(
        self,
        root: Union[str, Path],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        allow_empty: bool = False,
    ) -> None:
        super().__init__(
            root,
            loader,
            # The parent rejects receiving both an extension filter and is_valid_file, so the
            # default image extensions are only used when no custom predicate is supplied.
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
            allow_empty=allow_empty,
        )
        self.imgs = self.samples