Commit 90091b5b authored by Dhananjay's avatar Dhananjay Committed by Francisco Massa
Browse files

added is_valid_file parameter to datasets folders (#867)

* added is_valid_file option

* small fixes

* fixes

* flake8 fixes

* some test

* flake8 fixes

* improvements

* modifications on tests
* fixes

* minor fix
parent 2c4e68e4
......@@ -36,6 +36,20 @@ class Tester(unittest.TestCase):
outputs = sorted([dataset[i] for i in range(len(dataset))])
self.assertEqual(imgs, outputs)
dataset = ImageFolder(Tester.root, loader=lambda x: x, is_valid_file=lambda x: '3' in x)
self.assertEqual(sorted(Tester.classes), sorted(dataset.classes))
for cls in Tester.classes:
self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]])
class_a_idx = dataset.class_to_idx['a']
class_b_idx = dataset.class_to_idx['b']
imgs_a = [(img_path, class_a_idx)for img_path in Tester.class_a_images if '3' in img_path]
imgs_b = [(img_path, class_b_idx)for img_path in Tester.class_b_images if '3' in img_path]
imgs = sorted(imgs_a + imgs_b)
self.assertEqual(imgs, dataset.imgs)
outputs = sorted([dataset[i] for i in range(len(dataset))])
self.assertEqual(imgs, outputs)
def test_transform(self):
return_value = get_file_path_2('test/assets/dataset/a/a1.png')
......
......@@ -32,18 +32,24 @@ def is_image_file(filename):
return has_file_allowed_extension(filename, IMG_EXTENSIONS)
def make_dataset(dir, class_to_idx, extensions):
def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
images = []
dir = os.path.expanduser(dir)
if extensions is None and is_valid_file is None:
raise ValueError("Both extensions and is_valid_file cannot be None")
if extensions is not None and is_valid_file is not None:
raise ValueError("One of the extensions and is_valid_file should be None")
if extensions is not None:
def is_valid_file(x):
return has_file_allowed_extension(x, extensions)
for target in sorted(class_to_idx.keys()):
d = os.path.join(dir, target)
if not os.path.isdir(d):
continue
for root, _, fnames in sorted(os.walk(d)):
for fname in sorted(fnames):
if has_file_allowed_extension(fname, extensions):
path = os.path.join(root, fname)
if is_valid_file(path):
item = (path, class_to_idx[target])
images.append(item)
......@@ -65,11 +71,15 @@ class DatasetFolder(VisionDataset):
root (string): Root directory path.
loader (callable): A function to load a sample given its path.
extensions (tuple[string]): A list of allowed extensions.
both extensions and is_valid_file should not be passed.
transform (callable, optional): A function/transform that takes in
a sample and returns a transformed version.
E.g, ``transforms.RandomCrop`` for images.
target_transform (callable, optional): A function/transform that takes
in the target and transforms it.
is_valid_file (callable, optional): A function that takes path of an Image file
and check if the file is a valid_file (used to check of corrupt files)
both extensions and is_valid_file should not be passed.
Attributes:
classes (list): List of the class names.
......@@ -78,12 +88,12 @@ class DatasetFolder(VisionDataset):
targets (list): The class_index value for each image in the dataset
"""
def __init__(self, root, loader, extensions, transform=None, target_transform=None):
def __init__(self, root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None):
super(DatasetFolder, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
classes, class_to_idx = self._find_classes(self.root)
samples = make_dataset(self.root, class_to_idx, extensions)
samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
if len(samples) == 0:
raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n"
"Supported extensions are: " + ",".join(extensions)))
......@@ -184,6 +194,8 @@ class ImageFolder(DatasetFolder):
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
is_valid_file (callable, optional): A function that takes path of an Image file
and check if the file is a valid_file (used to check of corrupt files)
Attributes:
classes (list): List of the class names.
......@@ -192,8 +204,9 @@ class ImageFolder(DatasetFolder):
"""
def __init__(self, root, transform=None, target_transform=None,
loader=default_loader):
super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
loader=default_loader, is_valid_file=None):
super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
transform=transform,
target_transform=target_transform)
target_transform=target_transform,
is_valid_file=is_valid_file)
self.imgs = self.samples
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment