"git@developer.sourcefind.cn:OpenDAS/torchani.git" did not exist on "c72ebd890b95b0037e86d94fb10b4402a722339e"
Commit 3f6c23c0 authored by Ernest Parke's avatar Ernest Parke Committed by Francisco Massa
Browse files

Addresses issue #145 as per @fmessa's suggestion. (#527)

* Addresses issue #145 as per @fmessa's suggestion.

* Removed blank line for styling.
parent 5a0d079c
...@@ -32,17 +32,10 @@ def is_image_file(filename): ...@@ -32,17 +32,10 @@ def is_image_file(filename):
return has_file_allowed_extension(filename, IMG_EXTENSIONS) return has_file_allowed_extension(filename, IMG_EXTENSIONS)
def find_classes(dir):
classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx
def make_dataset(dir, class_to_idx, extensions): def make_dataset(dir, class_to_idx, extensions):
images = [] images = []
dir = os.path.expanduser(dir) dir = os.path.expanduser(dir)
for target in sorted(os.listdir(dir)): for target in sorted(class_to_idx.keys()):
d = os.path.join(dir, target) d = os.path.join(dir, target)
if not os.path.isdir(d): if not os.path.isdir(d):
continue continue
...@@ -86,7 +79,7 @@ class DatasetFolder(data.Dataset): ...@@ -86,7 +79,7 @@ class DatasetFolder(data.Dataset):
""" """
def __init__(self, root, loader, extensions, transform=None, target_transform=None): def __init__(self, root, loader, extensions, transform=None, target_transform=None):
classes, class_to_idx = find_classes(root) classes, class_to_idx = self._find_classes(root)
samples = make_dataset(root, class_to_idx, extensions) samples = make_dataset(root, class_to_idx, extensions)
if len(samples) == 0: if len(samples) == 0:
raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n" raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
...@@ -104,6 +97,24 @@ class DatasetFolder(data.Dataset): ...@@ -104,6 +97,24 @@ class DatasetFolder(data.Dataset):
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
def _find_classes(self, dir):
"""
Finds the class folders in a dataset.
Args:
dir (string): Root directory path.
Returns:
tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
Ensures:
No class is a subdirectory of another.
"""
classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx
def __getitem__(self, index): def __getitem__(self, index):
""" """
Args: Args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment