folder.py 11.6 KB
Newer Older
soumith's avatar
soumith committed
1
2
import os
import os.path
3
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
soumith's avatar
soumith committed
4

5
6
7
8
from PIL import Image

from .vision import VisionDataset

9

10
def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool:
    """Checks if a file is an allowed extension.

    Args:
        filename (string): path to a file
        extensions (tuple of strings): extensions to consider (lowercase).
            A single string is also accepted and treated as one extension.

    Returns:
        bool: True if the filename ends with one of given extensions
    """
    # ``str.endswith`` only accepts a string or a tuple of strings, so any other
    # iterable of extensions is converted to a tuple first.
    suffixes = extensions if isinstance(extensions, str) else tuple(extensions)
    # Only the filename is lower-cased; the extensions are compared exactly as given.
    return filename.lower().endswith(suffixes)
soumith's avatar
soumith committed
21

22

Philip Meier's avatar
Philip Meier committed
23
def is_image_file(filename: str) -> bool:
    """Checks if a file is an allowed image extension.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with a known image extension
    """
    # IMG_EXTENSIONS is a module-level constant defined further down in this file;
    # it is resolved at call time, so the definition order does not matter.
    return has_file_allowed_extension(filename, IMG_EXTENSIONS)


35
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    See :class:`DatasetFolder` for details.

    Args:
        directory (str): Directory whose immediate sub-directories are the classes.

    Raises:
        FileNotFoundError: If ``directory`` contains no sub-directory.

    Returns:
        (Tuple[List[str], Dict[str, int]]): The sorted class names and a dictionary
        mapping each class name to its position in that sorted list.
    """
    # Use the scandir iterator as a context manager so the underlying directory
    # handle is released deterministically, even if ``is_dir()`` raises.
    with os.scandir(directory) as entries:
        classes = sorted(entry.name for entry in entries if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx


Philip Meier's avatar
Philip Meier committed
48
49
def make_dataset(
    directory: str,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Union[str, Tuple[str, ...]]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    See :class:`DatasetFolder` for details.

    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
    by default.

    Args:
        directory (str): Root dataset directory; ``~`` is expanded.
        class_to_idx (Dict[str, int], optional): Mapping from class name (a sub-directory of
            ``directory``) to class index. Defaults to the result of ``find_classes(directory)``.
        extensions (optional): Allowed extension(s). Exactly one of ``extensions`` and
            ``is_valid_file`` must be passed.
        is_valid_file (optional): Predicate that receives a file path and returns whether
            the file should be included.

    Raises:
        ValueError: In case ``class_to_idx`` is empty.
        ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
        FileNotFoundError: In case no valid file was found for any class.

    Returns:
        List[Tuple[str, int]]: samples of a form (path_to_sample, class)
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

    # Exactly one of the two filters has to be supplied: reject "neither" and "both".
    if (extensions is None) == (is_valid_file is None):
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    # Classes, walked directories and file names are all sorted so the resulting
    # sample order is deterministic.
    for target_class in sorted(class_to_idx):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            # A class without a directory is reported below as having no valid file.
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    instances.append((path, class_index))
                    # ``set.add`` is idempotent, no membership test needed.
                    available_classes.add(target_class)

    empty_classes = set(class_to_idx) - available_classes
    if empty_classes:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances
soumith's avatar
soumith committed
105

106

107
class DatasetFolder(VisionDataset):
    """A generic data loader.

    This default directory structure can be customized by overriding the
    :meth:`find_classes` method.

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string]): A list of allowed extensions.
            both extensions and is_valid_file should not be passed.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes path of a file
            and check if the file is a valid file (used to check of corrupt files)
            both extensions and is_valid_file should not be passed.

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        # NOTE(review): ``self.root``, ``self.transform`` and ``self.target_transform`` are
        # presumably set by ``VisionDataset.__init__`` -- confirm against the base class.
        super().__init__(root, transform=transform, target_transform=target_transform)
        # Both steps go through ``self`` so that subclass overrides are honoured.
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        """Generates a list of samples of a form (path_to_sample, class).

        This can be overridden to e.g. read files from a compressed zip file instead of from the disk.

        Args:
            directory (str): root dataset directory, corresponding to ``self.root``.
            class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
            extensions (optional): A list of allowed extensions.
                Either extensions or is_valid_file should be passed. Defaults to None.
            is_valid_file (optional): A function that takes path of a file
                and checks if the file is a valid file
                (used to check of corrupt files) both extensions and
                is_valid_file should not be passed. Defaults to None.

        Raises:
            ValueError: In case ``class_to_idx`` is None or empty.
            ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
            FileNotFoundError: In case no valid file was found for any class.

        Returns:
            List[Tuple[str, int]]: samples of a form (path_to_sample, class)
        """
        if class_to_idx is None:
            # prevent potential bug since make_dataset() would use the class_to_idx logic of the
            # find_classes() function, instead of using that of the find_classes() method, which
            # is potentially overridden and thus could have a different logic.
            raise ValueError("The class_to_idx parameter cannot be None.")
        return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        """Find the class folders in a dataset structured as follows::

            directory/
            ├── class_x
            │   ├── xxx.ext
            │   ├── xxy.ext
            │   └── ...
            │       └── xxz.ext
            └── class_y
                ├── 123.ext
                ├── nsdf3.ext
                └── ...
                └── asd932_.ext

        This method can be overridden to only consider
        a subset of classes, or to adapt to a different dataset directory structure.

        Args:
            directory(str): Root directory path, corresponding to ``self.root``

        Raises:
            FileNotFoundError: If ``directory`` has no class folders.

        Returns:
            (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
        """
        return find_classes(directory)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        # The sample is loaded on every access; nothing is cached here.
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        """Return the number of collected samples."""
        return len(self.samples)


241
# Lowercase file suffixes (including the leading dot) that ``is_image_file`` and
# ``ImageFolder`` accept as images.
IMG_EXTENSIONS = (
    ".jpg",
    ".jpeg",
    ".png",
    ".ppm",
    ".bmp",
    ".pgm",
    ".tif",
    ".tiff",
    ".webp",
)
242
243


Philip Meier's avatar
Philip Meier committed
244
def pil_loader(path: str) -> Image.Image:
    """Load the image at ``path`` with PIL and return it converted to RGB.

    Args:
        path (str): Path of the image file.

    Returns:
        PIL.Image.Image: The result of ``convert("RGB")`` on the opened image.
    """
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, "rb") as f:
        img = Image.open(f)
        # Convert while the file handle is still open.
        return img.convert("RGB")
249
250


Philip Meier's avatar
Philip Meier committed
251
252
# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    """Load the image at ``path`` with ``accimage``, falling back to :func:`pil_loader`.

    Args:
        path (str): Path of the image file.

    Returns:
        An ``accimage.Image``, or the result of ``pil_loader(path)`` when ``accimage``
        raises ``OSError``.
    """
    # Imported lazily so that ``accimage`` is only required when this loader is used.
    import accimage

    try:
        return accimage.Image(path)
    except OSError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


Philip Meier's avatar
Philip Meier committed
262
def default_loader(path: str) -> Any:
    """Load the image at ``path`` with the loader matching torchvision's image backend.

    Args:
        path (str): Path of the image file.

    Returns:
        The result of ``accimage_loader(path)`` when ``get_image_backend()`` returns
        ``"accimage"``, otherwise the result of ``pil_loader(path)``.
    """
    # Imported inside the function, as in the original code
    # (presumably to avoid a circular import -- TODO confirm).
    from torchvision import get_image_backend

    if get_image_backend() == "accimage":
        return accimage_loader(path)
    else:
        return pil_loader(path)


271
class ImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way by default: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png

    This class inherits from :class:`~torchvision.datasets.DatasetFolder` so
    the same methods can be overridden to customize the dataset.

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and check if the file is a valid file (used to check of corrupt files)

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        super().__init__(
            root,
            loader,
            # Exactly one filter may be passed on: fall back to the known image
            # extensions only when the caller did not supply ``is_valid_file``.
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        # ``imgs`` is an alias of ``samples`` (same list object), not a copy.
        self.imgs = self.samples