folder.py 9.42 KB
Newer Older
1
from .vision import VisionDataset
soumith's avatar
soumith committed
2
3

from PIL import Image
4

soumith's avatar
soumith committed
5
6
import os
import os.path
Philip Meier's avatar
Philip Meier committed
7
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
soumith's avatar
soumith committed
8

9

Philip Meier's avatar
Philip Meier committed
10
def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
    """Checks if a file is an allowed extension.

    The comparison is case-insensitive on both sides: ``"IMG.JPG"`` matches
    ``".jpg"`` and ``"img.jpg"`` matches ``".JPG"``.

    Args:
        filename (string): path to a file
        extensions (tuple of strings): extensions to consider. A single string
            or any other iterable of strings (e.g. a list) is also accepted.

    Returns:
        bool: True if the filename ends with one of given extensions
    """
    # str.endswith only accepts a str or a tuple, so a list/set/generator must be
    # converted. A lone str is wrapped first so it is not split into characters.
    if isinstance(extensions, str):
        extensions = (extensions,)
    # The filename is lowercased below, so the extensions must be too, or
    # upper-case entries such as ".JPG" could never match.
    normalized = tuple(ext.lower() for ext in extensions)
    return filename.lower().endswith(normalized)
soumith's avatar
soumith committed
21

22

Philip Meier's avatar
Philip Meier committed
23
def is_image_file(filename: str) -> bool:
    """Tell whether ``filename`` carries one of the suffixes listed in ``IMG_EXTENSIONS``.

    Args:
        filename (string): path to a file

    Returns:
        bool: True when the (lowercased) filename ends with a known image extension
    """
    image_suffixes = IMG_EXTENSIONS
    return has_file_allowed_extension(filename, extensions=image_suffixes)


Philip Meier's avatar
Philip Meier committed
35
36
37
38
39
40
def make_dataset(
    directory: str,
    class_to_idx: Dict[str, int],
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    Classes are visited in sorted order of their names. Within each class
    directory, the tree is walked recursively (following symlinks) and files are
    visited in sorted order. A class whose directory does not exist under
    ``directory`` is skipped silently.

    Args:
        directory (str): root dataset directory (``~`` is expanded)
        class_to_idx (Dict[str, int]): dictionary mapping class name to class index
        extensions (optional): Allowed file extensions, matched case-insensitively.
            A tuple, a list (or other iterable) of strings, or a single string is
            accepted. Either extensions or is_valid_file should be passed.
            Defaults to None.
        is_valid_file (optional): A function that takes path of a file
            and checks if the file is a valid file
            (used to check for corrupt files). Both extensions and
            is_valid_file should not be passed. Defaults to None.

    Raises:
        ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.

    Returns:
        List[Tuple[str, int]]: samples of a form (path_to_sample, class)
    """
    instances: List[Tuple[str, int]] = []
    directory = os.path.expanduser(directory)
    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
    if extensions is not None:
        # Normalize once, outside the per-file loop. str.endswith() rejects lists,
        # and has_file_allowed_extension lowercases only the filename, so build a
        # lowercase tuple here. A lone str is wrapped rather than split into chars.
        if isinstance(extensions, str):
            allowed: Tuple[str, ...] = (extensions.lower(),)
        else:
            allowed = tuple(ext.lower() for ext in extensions)

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, allowed)
    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        # Sorting the walk and the filenames makes the sample order deterministic.
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    instances.append((path, class_index))
    return instances
soumith's avatar
soumith committed
81

82

83
class DatasetFolder(VisionDataset):
    """A generic data loader where the samples are arranged in this way: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/[...]/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/[...]/asd932_.ext

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string]): A tuple of allowed extensions.
            Pass exactly one of ``extensions`` and ``is_valid_file``; the default
            ``make_dataset`` raises ``ValueError`` otherwise.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes path of a file
            and checks if the file is a valid file (used to check for corrupt files).
            Pass exactly one of ``extensions`` and ``is_valid_file``.

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each image in the dataset

    Raises:
        RuntimeError: If no sample files are found under ``root``.
    """

    def __init__(
            self,
            root: str,
            loader: Callable[[str], Any],
            extensions: Optional[Tuple[str, ...]] = None,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        classes, class_to_idx = self._find_classes(self.root)
        # Called through ``self`` so that a subclass override of make_dataset is used.
        samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
        if len(samples) == 0:
            msg = "Found 0 files in subfolders of: {}\n".format(self.root)
            if extensions is not None:
                msg += "Supported extensions are: {}".format(",".join(extensions))
            raise RuntimeError(msg)

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        """Build the ``(path, class_index)`` sample list; delegates to the module-level ``make_dataset``."""
        return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)

    def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
        """
        Finds the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

        Ensures:
            No class is a subdirectory of another.
        """
        # Only immediate child directories become classes. The context manager closes
        # the scandir iterator even if an entry's is_dir() raises mid-iteration.
        with os.scandir(dir) as entries:
            classes = sorted(entry.name for entry in entries if entry.is_dir())
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        """Number of samples discovered under ``root``."""
        return len(self.samples)


190
# Lowercase, dot-prefixed image suffixes. ImageFolder filters with these when no
# custom is_valid_file is supplied, and is_image_file checks against them.
IMG_EXTENSIONS = (
    '.jpg',
    '.jpeg',
    '.png',
    '.ppm',
    '.bmp',
    '.pgm',
    '.tif',
    '.tiff',
    '.webp',
)
191
192


Philip Meier's avatar
Philip Meier committed
193
def pil_loader(path: str) -> Image.Image:
    """Open the image at ``path`` with PIL and return an RGB-converted copy.

    The file handle is opened here and closed by the ``with`` block, rather than
    passing the path to PIL, to avoid a ResourceWarning
    (https://github.com/python-pillow/Pillow/issues/835).
    """
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')
198
199


Philip Meier's avatar
Philip Meier committed
200
201
# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    """Load ``path`` with accimage, falling back to ``pil_loader`` if accimage raises IOError."""
    import accimage
    try:
        image = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        image = pil_loader(path)
    return image


Philip Meier's avatar
Philip Meier committed
210
def default_loader(path: str) -> Any:
    """Load an image using whichever backend torchvision currently reports.

    Dispatches to ``accimage_loader`` when ``get_image_backend()`` returns
    ``'accimage'``, and to ``pil_loader`` for any other value.
    """
    from torchvision import get_image_backend
    use_accimage = get_image_backend() == 'accimage'
    return accimage_loader(path) if use_accimage else pil_loader(path)


218
class ImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and checks if the file is a valid file (used to check for corrupt files).
            When given, it replaces the default ``IMG_EXTENSIONS`` extension filter.

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(
            self,
            root: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            loader: Callable[[str], Any] = default_loader,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        # Only one of extensions / is_valid_file may be passed on to DatasetFolder, so
        # the image-extension filter applies only when no custom validator is given.
        super().__init__(
            root,
            loader,
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        # Same list object as self.samples, exposed under an image-specific name.
        self.imgs = self.samples