Unverified Commit 7c052cea authored by Soumith Chintala, committed by GitHub

Revert "Adding a DatasetFolder class. (#442)" (#443)

This reverts commit ab03dc43.
parent ab03dc43
@@ -4,11 +4,11 @@ torchvision.datasets
All datasets are subclasses of :class:`torch.utils.data.Dataset`
i.e, they have ``__getitem__`` and ``__len__`` methods implemented.
Hence, they can all be passed to a :class:`torch.utils.data.DataLoader`
which can load multiple samples in parallel using ``torch.multiprocessing`` workers.
For example: ::

    imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')
    data_loader = torch.utils.data.DataLoader(imagenet_data,
                                              batch_size=4,
                                              shuffle=True,
                                              num_workers=args.nThreads)
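Once constructed, the loader can be iterated to produce batches of samples and
targets. A minimal sketch (it assumes a ``transform`` such as a fixed-size crop
followed by ``transforms.ToTensor()`` is passed to the dataset, so that images
share a common shape and can be collated into a single tensor; ``train_step``
is a placeholder for your own training code): ::

    for images, targets in data_loader:
        # images: FloatTensor of shape (batch_size, 3, H, W)
        # targets: LongTensor of class indices, shape (batch_size,)
        train_step(images, targets)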
@@ -22,7 +22,7 @@ All the datasets have almost similar API. They all have two common arguments:
``transform`` and ``target_transform`` to transform the input and target respectively.
.. currentmodule:: torchvision.datasets
MNIST
@@ -78,14 +78,6 @@ ImageFolder
    :members: __getitem__
    :special-members:
DatasetFolder
~~~~~~~~~~~~~

.. autoclass:: DatasetFolder
    :members: __getitem__
    :special-members:
Imagenet-12
~~~~~~~~~~~
@@ -129,3 +121,4 @@ PhotoTour
.. autoclass:: PhotoTour
    :members: __getitem__
    :special-members:
@@ -33,7 +33,6 @@ requirements = [
    'pillow >= 4.1.1',
    'six',
    'torch',
    'mock',
]
setup(
......
import unittest
try:
    from unittest.mock import Mock
except ImportError:
    from mock import Mock
import os

from torchvision.datasets import ImageFolder


class Tester(unittest.TestCase):
    root = 'test/assets/dataset/'
    classes = ['a', 'b']
    class_a_images = [os.path.join('test/assets/dataset/a/', path)
                      for path in ['a1.png', 'a2.png', 'a3.png']]
    class_b_images = [os.path.join('test/assets/dataset/b/', path)
                      for path in ['b1.png', 'b2.png', 'b3.png', 'b4.png']]

    def test_image_folder(self):
        dataset = ImageFolder(Tester.root, loader=lambda x: x)
        self.assertEqual(sorted(Tester.classes), sorted(dataset.classes))
        for cls in Tester.classes:
            self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]])
        class_a_idx = dataset.class_to_idx['a']
        class_b_idx = dataset.class_to_idx['b']
        imgs_a = [(img_path, class_a_idx) for img_path in Tester.class_a_images]
        imgs_b = [(img_path, class_b_idx) for img_path in Tester.class_b_images]
        imgs = sorted(imgs_a + imgs_b)
        self.assertEqual(imgs, dataset.imgs)
        outputs = sorted([dataset[i] for i in range(len(dataset))])
        self.assertEqual(imgs, outputs)

    def test_transform(self):
        return_value = 'test/assets/dataset/a/a1.png'
        transform = Mock(return_value=return_value)
        dataset = ImageFolder(Tester.root, loader=lambda x: x, transform=transform)
        outputs = [dataset[i][0] for i in range(len(dataset))]
        self.assertEqual([return_value] * len(outputs), outputs)
        imgs = sorted(Tester.class_a_images + Tester.class_b_images)
        args = [call[0][0] for call in transform.call_args_list]
        self.assertEqual(imgs, sorted(args))

    def test_target_transform(self):
        return_value = 1
        target_transform = Mock(return_value=return_value)
        dataset = ImageFolder(Tester.root, loader=lambda x: x,
                              target_transform=target_transform)
        outputs = [dataset[i][1] for i in range(len(dataset))]
        self.assertEqual([return_value] * len(outputs), outputs)
        class_a_idx = dataset.class_to_idx['a']
        class_b_idx = dataset.class_to_idx['b']
        targets = sorted([class_a_idx] * len(Tester.class_a_images) +
                         [class_b_idx] * len(Tester.class_b_images))
        args = [call[0][0] for call in target_transform.call_args_list]
        self.assertEqual(targets, sorted(args))


if __name__ == '__main__':
    unittest.main()
from .lsun import LSUN, LSUNClass
from .folder import ImageFolder, DatasetFolder
from .folder import ImageFolder
from .coco import CocoCaptions, CocoDetection
from .cifar import CIFAR10, CIFAR100
from .stl10 import STL10
@@ -11,7 +11,7 @@ from .semeion import SEMEION
from .omniglot import Omniglot

__all__ = ('LSUN', 'LSUNClass',
           'ImageFolder', 'DatasetFolder', 'FakeData',
           'ImageFolder', 'FakeData',
           'CocoCaptions', 'CocoDetection',
           'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
           'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
......
import torch.utils.data as data
from PIL import Image
import os
import os.path
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']
def has_file_allowed_extension(filename, extensions):
    """Checks if a file is an allowed extension.
def is_image_file(filename):
    """Checks if a file is an image.

    Args:
        filename (string): path to a file
@@ -16,7 +17,7 @@ def has_file_allowed_extension(filename, extensions):
        bool: True if the filename ends with a known image extension
    """
    filename_lower = filename.lower()
    return any(filename_lower.endswith(ext) for ext in extensions)
    return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS)
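# A quick illustration of the extension check above (a sketch; it assumes the
# post-revert helper name ``is_image_file`` and that the module is importable
# as torchvision.datasets.folder):
from torchvision.datasets.folder import is_image_file

print(is_image_file('photo.PNG'))   # True  -- the filename is lower-cased before matching
print(is_image_file('notes.txt'))   # False -- '.txt' is not in IMG_EXTENSIONS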
def find_classes(dir):
@@ -26,7 +27,7 @@ def find_classes(dir):
    return classes, class_to_idx
def make_dataset(dir, class_to_idx, extensions):
def make_dataset(dir, class_to_idx):
    images = []
    dir = os.path.expanduser(dir)
    for target in sorted(os.listdir(dir)):
@@ -36,7 +37,7 @@ def make_dataset(dir, class_to_idx, extensions):
        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                if has_file_allowed_extension(fname, extensions):
                if is_image_file(fname):
                    path = os.path.join(root, fname)
                    item = (path, class_to_idx[target])
                    images.append(item)
@@ -44,85 +45,6 @@ def make_dataset(dir, class_to_idx, extensions):
    return images
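# Sketch of what make_dataset returns: a flat list of (path, class_index)
# tuples, one per file that passes the extension check. The directory below is
# the fixture used by test/test_folder.py; this assumes it exists relative to
# the working directory.
from torchvision.datasets.folder import find_classes, make_dataset

classes, class_to_idx = find_classes('test/assets/dataset/')   # ['a', 'b'], {'a': 0, 'b': 1}
samples = make_dataset('test/assets/dataset/', class_to_idx)   # post-revert signature (no extensions arg)
print(samples[:2])   # e.g. [('test/assets/dataset/a/a1.png', 0), ('test/assets/dataset/a/a2.png', 0)]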
class DatasetFolder(data.Dataset):
    """A generic data loader where the samples are arranged in this way: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (list[string]): A list of allowed extensions.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
    """

    def __init__(self, root, loader, extensions, transform=None, target_transform=None):
        classes, class_to_idx = find_classes(root)
        samples = make_dataset(root, class_to_idx, extensions)
        if len(samples) == 0:
            raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
                               "Supported extensions are: " + ",".join(extensions)))

        self.root = root
        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples

        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
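# Usage sketch for the DatasetFolder class that this revert removes: any sample
# type can be handled by supplying a loader callable and an extensions list.
# The '.npy' setup below is a hypothetical illustration, not torchvision API.
import numpy as np

npy_dataset = DatasetFolder('path/to/root', loader=np.load, extensions=['.npy'])
array, class_index = npy_dataset[0]   # array: numpy.ndarray, class_index: int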
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']
def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
@@ -147,7 +69,7 @@ def default_loader(path):
        return pil_loader(path)
class ImageFolder(DatasetFolder):
class ImageFolder(data.Dataset):
"""A generic data loader where the images are arranged in this way: ::
root/dog/xxx.png
......@@ -171,9 +93,49 @@ class ImageFolder(DatasetFolder):
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
                                          transform=transform,
                                          target_transform=target_transform)
        self.imgs = self.samples
        classes, class_to_idx = find_classes(root)
        imgs = make_dataset(root, class_to_idx)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.imgs)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
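# Usage sketch for ImageFolder (its behaviour is the same before and after the
# revert; only whether it subclasses DatasetFolder changes). The root path and
# transform below are illustrative assumptions.
from torchvision import transforms

dataset = ImageFolder('path/to/root',
                      transform=transforms.Compose([
                          transforms.Resize((224, 224)),
                          transforms.ToTensor(),
                      ]))
print(dataset.classes)      # e.g. ['cat', 'dog'], one entry per sub-directory
img, target = dataset[0]    # img: 3x224x224 FloatTensor, target: class index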