import torch.utils.data as data

from PIL import Image
import os
import os.path

# File-name suffixes treated as images, each listed in both its lower-case
# and upper-case spelling. The list is also joined into the error message
# raised when a dataset root contains no images.
IMG_EXTENSIONS = [
    # JPEG
    '.jpg', '.JPG',
    '.jpeg', '.JPEG',
    # PNG, PPM, BMP
    '.png', '.PNG',
    '.ppm', '.PPM',
    '.bmp', '.BMP',
]

def is_image_file(filename):
    """Return True if *filename* ends with one of the IMG_EXTENSIONS suffixes.

    The comparison is case-insensitive. Mixed-case suffixes such as ``.Jpg``
    are therefore recognised as well as the spellings listed explicitly.

    Args:
        filename: a file name or path string.

    Returns:
        bool: whether the name has a recognised image extension.
    """
    # str.endswith accepts a tuple and checks every suffix in a single call.
    suffixes = tuple(extension.lower() for extension in IMG_EXTENSIONS)
    return filename.lower().endswith(suffixes)

def find_classes(dir):
    """Discover class names from the sub-directories of *dir*.

    Each immediate sub-directory of ``dir`` is one class. Plain files at the
    top level (e.g. a stray README or ``.DS_Store``) are ignored. Otherwise
    they would become bogus classes and shift the indices of the real ones.

    Args:
        dir: root directory containing one sub-directory per class.

    Returns:
        A ``(classes, class_to_idx)`` tuple:

        - ``classes``: the sorted list of class names.
        - ``class_to_idx``: a dict mapping each name to its position in
          ``classes``.
    """
    # Only directories can hold samples, so only directories are classes.
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx

def make_dataset(dir, class_to_idx):
    """Collect ``(relative_path, class_index)`` pairs for the images under *dir*.

    Only the immediate sub-directories of ``dir`` are scanned (no recursion).
    Within each, only entries whose names pass ``is_image_file`` are kept.
    Returned paths are relative to ``dir``.

    Both directory levels are visited in sorted order. The resulting list,
    and therefore every sample index, is then the same on every file system.

    Args:
        dir: root directory with one sub-directory per class.
        class_to_idx: mapping from sub-directory name to class index.

    Returns:
        A list of ``(path, class_index)`` tuples.

    Raises:
        KeyError: if a sub-directory of ``dir`` has no entry in
            ``class_to_idx``.
    """
    images = []
    # os.listdir order is arbitrary; sort for a reproducible sample order.
    for target in sorted(os.listdir(dir)):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue

        for filename in sorted(os.listdir(d)):
            if is_image_file(filename):
                # Keep the path relative to ``dir``. os.path.join uses the
                # platform separator rather than a hard-coded '/'.
                path = os.path.join(target, filename)
                images.append((path, class_to_idx[target]))

    return images


def default_loader(path):
    """Load the image file at *path* and return it converted to RGB.

    The file is opened explicitly in a ``with`` block, so the handle is closed
    as soon as the pixels are decoded. The original code left it for the
    garbage collector to close; with many images that can exhaust file
    descriptors and it emits ResourceWarning.

    Args:
        path: path to an image file readable by PIL.

    Returns:
        A PIL image in ``'RGB'`` mode.
    """
    with open(path, 'rb') as f:
        img = Image.open(f)
        # Image.open is lazy. convert() loads the pixel data and returns a new
        # image, so the result no longer needs ``f`` after the block exits.
        return img.convert('RGB')


class ImageFolder(data.Dataset):
    """Dataset of images laid out as ``root/<class_name>/<image_file>``.

    Classes are discovered with ``find_classes(root)``. Samples are
    collected with ``make_dataset(root, class_to_idx)``.

    Args:
        root: root directory of the dataset.
        transform: optional callable applied to each loaded image.
        target_transform: optional callable applied to each class index.
        loader: callable that takes a file path and returns the loaded image.
            Defaults to ``default_loader``.

    Attributes:
        root: the ``root`` argument.
        classes: class names as returned by ``find_classes``.
        class_to_idx: mapping from class name to class index.
        imgs: list of ``(path_relative_to_root, class_index)`` pairs.
        transform, target_transform, loader: the corresponding arguments.

    Raises:
        RuntimeError: at construction time if no image files are found.
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        classes, class_to_idx = find_classes(root)
        imgs = make_dataset(root, class_to_idx)
        if not imgs:
            # str(root) so a pathlib.Path root does not break the message.
            raise RuntimeError(
                "Found 0 images in subfolders of: " + str(root) + "\n"
                "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair for sample *index*.

        The image is read with ``self.loader`` from ``root/<relative path>``.
        ``transform`` and ``target_transform`` are applied when they are set.
        """
        path, target = self.imgs[index]
        img = self.loader(os.path.join(self.root, path))
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.imgs)