folder.py 3.85 KB
Newer Older
soumith's avatar
soumith committed
1
2
3
4
5
6
import torch.utils.data as data

from PIL import Image
import os
import os.path

7
# Lower-case file suffixes recognised as images. is_image_file lower-cases the
# filename before comparing against these, so matching is case-insensitive.
IMG_EXTENSIONS = [
    '.jpg',
    '.jpeg',
    '.png',
    '.ppm',
    '.bmp',
    '.pgm',
]
soumith's avatar
soumith committed
8

9

soumith's avatar
soumith committed
10
def is_image_file(filename):
    """Tells whether ``filename`` looks like an image, judging by its extension.

    The filename is lower-cased before it is matched against the suffixes in
    ``IMG_EXTENSIONS``, so ``"A.JPG"`` and ``"a.jpg"`` are treated alike.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with a known image extension
    """
    # str.endswith accepts a tuple and succeeds if any suffix matches.
    return filename.lower().endswith(tuple(IMG_EXTENSIONS))
soumith's avatar
soumith committed
21

22

soumith's avatar
soumith committed
23
def find_classes(dir):
    """Finds the class folders directly under a dataset root.

    Args:
        dir (string): Root directory path. A leading ``~`` is expanded, the
            same way ``make_dataset`` expands the root it is given.

    Returns:
        tuple: ``(classes, class_to_idx)`` where ``classes`` is the sorted
        list of sub-directory names of ``dir`` and ``class_to_idx`` maps each
        of those names to its position in ``classes``.
    """
    # make_dataset expands "~" but this function is called first by
    # ImageFolder, so without this a root like "~/data" fails in os.listdir.
    dir = os.path.expanduser(dir)
    # Only directories count as classes; stray files in the root are ignored.
    classes = sorted(d for d in os.listdir(dir)
                     if os.path.isdir(os.path.join(dir, d)))
    class_to_idx = {class_name: idx for idx, class_name in enumerate(classes)}
    return classes, class_to_idx

29

soumith's avatar
soumith committed
30
31
def make_dataset(dir, class_to_idx):
    """Collects ``(path, class_index)`` pairs for every image under ``dir``.

    Each immediate sub-directory of ``dir`` is one class. It is walked
    recursively, and every file accepted by ``is_image_file`` is recorded with
    that sub-directory's index from ``class_to_idx``. Plain files directly in
    ``dir`` are skipped.

    Args:
        dir (string): Root directory path (a leading ``~`` is expanded).
        class_to_idx (dict): Maps class (sub-directory) name to class index.

    Returns:
        list: ``(image path, class_index)`` tuples, ordered by sorted class
        name, then sorted walk entry, then sorted file name.

    Raises:
        KeyError: if an image is found under a sub-directory whose name is
            not a key of ``class_to_idx``.
    """
    samples = []
    root_dir = os.path.expanduser(dir)
    for class_name in sorted(os.listdir(root_dir)):
        class_dir = os.path.join(root_dir, class_name)
        if not os.path.isdir(class_dir):
            continue
        for folder, _subdirs, file_names in sorted(os.walk(class_dir)):
            # The label is looked up per image (not once per class) so a
            # class folder without images never triggers a KeyError.
            samples.extend(
                (os.path.join(folder, name), class_to_idx[class_name])
                for name in sorted(file_names)
                if is_image_file(name)
            )
    return samples

47

48
def pil_loader(path):
    """Loads the image at ``path`` with PIL and returns it converted to RGB.

    Args:
        path (string): path to an image file.

    Returns:
        The result of ``convert('RGB')`` on the opened image. The file handle
        and the opened source image are both closed before returning.
    """
    # Hand PIL an explicitly opened file object so the handle is closed
    # deterministically (avoids ResourceWarning,
    # https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as handle, Image.open(handle) as source:
        rgb = source.convert('RGB')
    return rgb
53
54


55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def accimage_loader(path):
    """Loads ``path`` with accimage, falling back to PIL if accimage fails.

    Args:
        path (string): path to an image file.

    Returns:
        An ``accimage.Image`` on success. If accimage raises IOError, the
        result of ``pil_loader(path)`` is returned instead.
    """
    # Imported inside the function, so importing this module does not
    # require accimage to be installed.
    import accimage
    try:
        image = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        image = pil_loader(path)
    return image


def default_loader(path):
    """Loads an image with the loader matching torchvision's image backend.

    Args:
        path (string): path to an image file.

    Returns:
        ``accimage_loader(path)`` when ``torchvision.get_image_backend()``
        returns ``'accimage'``, otherwise ``pil_loader(path)``.
    """
    from torchvision import get_image_backend
    use_accimage = get_image_backend() == 'accimage'
    return accimage_loader(path) if use_accimage else pil_loader(path)


soumith's avatar
soumith committed
72
class ImageFolder(data.Dataset):
    """A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version, e.g. ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples.

    Raises:
        RuntimeError: if no file with a supported image extension is found in
            the class sub-directories of ``root``.
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        classes, class_to_idx = find_classes(root)
        imgs = make_dataset(root, class_to_idx)
        if not imgs:
            # str(root) keeps this a RuntimeError for non-str roots such as
            # pathlib.Path, which would otherwise fail the concatenation with
            # a confusing TypeError.
            raise RuntimeError("Found 0 images in subfolders of: " + str(root) + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Loads one sample.

        The image is read with ``self.loader``. ``transform`` is applied to
        the image and ``target_transform`` to the target, each only when set.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Returns the number of (image path, class_index) samples."""
        return len(self.imgs)