Commit ab03dc43 authored by Frédérik Paradis's avatar Frédérik Paradis Committed by Soumith Chintala
Browse files

Adding a DatasetFolder class. (#442)

* Adding tests to ImageFolder

* Adding DatasetFolder class

* Fix tests for pytest and code for lint checker

* Adding mock to requirements for ImageFolder tests
parent 456d3b97
...@@ -4,11 +4,11 @@ torchvision.datasets ...@@ -4,11 +4,11 @@ torchvision.datasets
All datasets are subclasses of :class:`torch.utils.data.Dataset` All datasets are subclasses of :class:`torch.utils.data.Dataset`
i.e., they have ``__getitem__`` and ``__len__`` methods implemented. i.e., they have ``__getitem__`` and ``__len__`` methods implemented.
Hence, they can all be passed to a :class:`torch.utils.data.DataLoader` Hence, they can all be passed to a :class:`torch.utils.data.DataLoader`
which can load multiple samples in parallel using ``torch.multiprocessing`` workers. which can load multiple samples in parallel using ``torch.multiprocessing`` workers.
For example: :: For example: ::
imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/') imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data, data_loader = torch.utils.data.DataLoader(imagenet_data,
batch_size=4, batch_size=4,
shuffle=True, shuffle=True,
num_workers=args.nThreads) num_workers=args.nThreads)
...@@ -22,7 +22,7 @@ All the datasets have almost similar API. They all have two common arguments: ...@@ -22,7 +22,7 @@ All the datasets have almost similar API. They all have two common arguments:
``transform`` and ``target_transform`` to transform the input and target respectively. ``transform`` and ``target_transform`` to transform the input and target respectively.
.. currentmodule:: torchvision.datasets .. currentmodule:: torchvision.datasets
MNIST MNIST
...@@ -78,6 +78,14 @@ ImageFolder ...@@ -78,6 +78,14 @@ ImageFolder
:members: __getitem__ :members: __getitem__
:special-members: :special-members:
DatasetFolder
~~~~~~~~~~~~~
.. autoclass:: DatasetFolder
:members: __getitem__
:special-members:
Imagenet-12 Imagenet-12
~~~~~~~~~~~ ~~~~~~~~~~~
...@@ -121,4 +129,3 @@ PhotoTour ...@@ -121,4 +129,3 @@ PhotoTour
.. autoclass:: PhotoTour .. autoclass:: PhotoTour
:members: __getitem__ :members: __getitem__
:special-members: :special-members:
...@@ -33,6 +33,7 @@ requirements = [ ...@@ -33,6 +33,7 @@ requirements = [
'pillow >= 4.1.1', 'pillow >= 4.1.1',
'six', 'six',
'torch', 'torch',
'mock',
] ]
setup( setup(
......
import unittest
try:
from unittest.mock import Mock
except ImportError as e:
from mock import Mock
import os
from torchvision.datasets import ImageFolder
class Tester(unittest.TestCase):
    """Exercise ``ImageFolder`` against a small two-class fixture directory.

    The fixture is expected to hold class ``a`` with three PNG files and
    class ``b`` with four.  All paths are relative, so they are resolved
    against the current working directory when the tests run.
    """

    root = 'test/assets/dataset/'
    classes = ['a', 'b']
    class_a_images = [
        os.path.join('test/assets/dataset/a/', name)
        for name in ['a1.png', 'a2.png', 'a3.png']
    ]
    class_b_images = [
        os.path.join('test/assets/dataset/b/', name)
        for name in ['b1.png', 'b2.png', 'b3.png', 'b4.png']
    ]

    @staticmethod
    def _identity(path):
        """Loader stand-in: hand the path back instead of opening the file."""
        return path

    def _open_dataset(self, **kwargs):
        """Build an ``ImageFolder`` over the fixture root with the identity loader."""
        return ImageFolder(Tester.root, loader=self._identity, **kwargs)

    @staticmethod
    def _first_positional_args(mock):
        """Return the first positional argument of every recorded call on ``mock``."""
        return [recorded[0][0] for recorded in mock.call_args_list]

    def test_image_folder(self):
        """Classes, the ``imgs`` list and indexed access all match the fixture."""
        dataset = self._open_dataset()

        self.assertEqual(sorted(Tester.classes), sorted(dataset.classes))
        # class_to_idx must be the inverse mapping of the classes list.
        for name in Tester.classes:
            self.assertEqual(name, dataset.classes[dataset.class_to_idx[name]])

        idx_a = dataset.class_to_idx['a']
        idx_b = dataset.class_to_idx['b']
        expected = [(image, idx_a) for image in Tester.class_a_images]
        expected.extend((image, idx_b) for image in Tester.class_b_images)
        expected.sort()
        self.assertEqual(expected, dataset.imgs)

        # With the identity loader each item is simply (path, class_index).
        fetched = [dataset[i] for i in range(len(dataset))]
        fetched.sort()
        self.assertEqual(expected, fetched)

    def test_transform(self):
        """``transform`` is called once per sample path and its result is returned."""
        sentinel = 'test/assets/dataset/a/a1.png'
        transform = Mock(return_value=sentinel)
        dataset = self._open_dataset(transform=transform)

        produced = [dataset[i][0] for i in range(len(dataset))]
        self.assertEqual([sentinel] * len(produced), produced)

        expected_inputs = sorted(Tester.class_a_images + Tester.class_b_images)
        seen_inputs = sorted(self._first_positional_args(transform))
        self.assertEqual(expected_inputs, seen_inputs)

    def test_target_transform(self):
        """``target_transform`` is called once per class index and its result is returned."""
        sentinel = 1
        target_transform = Mock(return_value=sentinel)
        dataset = self._open_dataset(target_transform=target_transform)

        produced = [dataset[i][1] for i in range(len(dataset))]
        self.assertEqual([sentinel] * len(produced), produced)

        idx_a = dataset.class_to_idx['a']
        idx_b = dataset.class_to_idx['b']
        expected_targets = sorted(
            [idx_a] * len(Tester.class_a_images) +
            [idx_b] * len(Tester.class_b_images)
        )
        seen_targets = sorted(self._first_positional_args(target_transform))
        self.assertEqual(expected_targets, seen_targets)
if __name__ == '__main__':
unittest.main()
from .lsun import LSUN, LSUNClass from .lsun import LSUN, LSUNClass
from .folder import ImageFolder from .folder import ImageFolder, DatasetFolder
from .coco import CocoCaptions, CocoDetection from .coco import CocoCaptions, CocoDetection
from .cifar import CIFAR10, CIFAR100 from .cifar import CIFAR10, CIFAR100
from .stl10 import STL10 from .stl10 import STL10
...@@ -11,7 +11,7 @@ from .semeion import SEMEION ...@@ -11,7 +11,7 @@ from .semeion import SEMEION
from .omniglot import Omniglot from .omniglot import Omniglot
__all__ = ('LSUN', 'LSUNClass', __all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'FakeData', 'ImageFolder', 'DatasetFolder', 'FakeData',
'CocoCaptions', 'CocoDetection', 'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION', 'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
......
import torch.utils.data as data import torch.utils.data as data
from PIL import Image from PIL import Image
import os import os
import os.path import os.path
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']
def is_image_file(filename): def has_file_allowed_extension(filename, extensions):
"""Checks if a file is an image. """Checks if a file is an allowed extension.
Args: Args:
filename (string): path to a file filename (string): path to a file
...@@ -17,7 +16,7 @@ def is_image_file(filename): ...@@ -17,7 +16,7 @@ def is_image_file(filename):
bool: True if the filename ends with a known image extension bool: True if the filename ends with a known image extension
""" """
filename_lower = filename.lower() filename_lower = filename.lower()
return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS) return any(filename_lower.endswith(ext) for ext in extensions)
def find_classes(dir): def find_classes(dir):
...@@ -27,7 +26,7 @@ def find_classes(dir): ...@@ -27,7 +26,7 @@ def find_classes(dir):
return classes, class_to_idx return classes, class_to_idx
def make_dataset(dir, class_to_idx): def make_dataset(dir, class_to_idx, extensions):
images = [] images = []
dir = os.path.expanduser(dir) dir = os.path.expanduser(dir)
for target in sorted(os.listdir(dir)): for target in sorted(os.listdir(dir)):
...@@ -37,7 +36,7 @@ def make_dataset(dir, class_to_idx): ...@@ -37,7 +36,7 @@ def make_dataset(dir, class_to_idx):
for root, _, fnames in sorted(os.walk(d)): for root, _, fnames in sorted(os.walk(d)):
for fname in sorted(fnames): for fname in sorted(fnames):
if is_image_file(fname): if has_file_allowed_extension(fname, extensions):
path = os.path.join(root, fname) path = os.path.join(root, fname)
item = (path, class_to_idx[target]) item = (path, class_to_idx[target])
images.append(item) images.append(item)
...@@ -45,6 +44,85 @@ def make_dataset(dir, class_to_idx): ...@@ -45,6 +44,85 @@ def make_dataset(dir, class_to_idx):
return images return images
class DatasetFolder(data.Dataset):
    """Generic dataset for samples stored one sub-directory per class: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext

    Args:
        root (string): Root directory path.
        loader (callable): Called with a sample's path; returns the loaded sample.
        extensions (list[string]): File extensions that count as samples.
        transform (callable, optional): Applied to each loaded sample before
            it is returned. E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): Applied to each class index
            before it is returned.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
    """

    def __init__(self, root, loader, extensions, transform=None, target_transform=None):
        classes, class_to_idx = find_classes(root)
        found = make_dataset(root, class_to_idx, extensions)
        if not found:
            # Refuse to build an empty dataset; report which extensions were searched.
            message = ("Found 0 files in subfolders of: " + root + "\n"
                       "Supported extensions are: " + ",".join(extensions))
            raise RuntimeError(message)

        self.root = root
        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = found

        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        # Both transforms are optional; each is skipped when left as None.
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        def labelled(label, obj):
            # Continuation lines of a multi-line repr are padded to sit under
            # the text that follows the label.
            padding = '\n' + ' ' * len(label)
            return '{0}{1}'.format(label, obj.__repr__().replace('\n', padding))

        lines = [
            'Dataset ' + self.__class__.__name__,
            ' Number of datapoints: {}'.format(self.__len__()),
            ' Root Location: {}'.format(self.root),
            labelled(' Transforms (if any): ', self.transform),
            labelled(' Target Transforms (if any): ', self.target_transform),
        ]
        return '\n'.join(lines)
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']
def pil_loader(path): def pil_loader(path):
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f: with open(path, 'rb') as f:
...@@ -69,7 +147,7 @@ def default_loader(path): ...@@ -69,7 +147,7 @@ def default_loader(path):
return pil_loader(path) return pil_loader(path)
class ImageFolder(data.Dataset): class ImageFolder(DatasetFolder):
"""A generic data loader where the images are arranged in this way: :: """A generic data loader where the images are arranged in this way: ::
root/dog/xxx.png root/dog/xxx.png
...@@ -93,49 +171,9 @@ class ImageFolder(data.Dataset): ...@@ -93,49 +171,9 @@ class ImageFolder(data.Dataset):
class_to_idx (dict): Dict with items (class_name, class_index). class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples imgs (list): List of (image path, class_index) tuples
""" """
def __init__(self, root, transform=None, target_transform=None, def __init__(self, root, transform=None, target_transform=None,
loader=default_loader): loader=default_loader):
classes, class_to_idx = find_classes(root) super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
imgs = make_dataset(root, class_to_idx) transform=transform,
if len(imgs) == 0: target_transform=target_transform)
raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" self.imgs = self.samples
"Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
self.root = root
self.imgs = imgs
self.classes = classes
self.class_to_idx = class_to_idx
self.transform = transform
self.target_transform = target_transform
self.loader = loader
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is class_index of the target class.
"""
path, target = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
return len(self.imgs)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment