Unverified Commit 0818c682 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Improve error handling in make_dataset (#3496)

* factor out find_classes

* use find_classes in video datasets

* adapt old tests
parent 19ad0bbc
......@@ -111,10 +111,10 @@ class Tester(DatasetTestcase):
def test_imagefolder_empty(self):
with get_tmp_dir() as root:
with self.assertRaises(RuntimeError):
with self.assertRaises(FileNotFoundError):
torchvision.datasets.ImageFolder(root, loader=lambda x: x)
with self.assertRaises(RuntimeError):
with self.assertRaises(FileNotFoundError):
torchvision.datasets.ImageFolder(
root, loader=lambda x: x, is_valid_file=lambda x: False
)
......@@ -1092,9 +1092,6 @@ class Kinetics400TestCase(datasets_utils.VideoDatasetTestCase):
return num_videos_per_class * len(classes)
def test_not_found_or_corrupted(self):
self.skipTest("Dataset currently does not handle the case of no found videos.")
class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
DATASET_CLASS = datasets.HMDB51
......
......@@ -32,9 +32,43 @@ def is_image_file(filename: str) -> bool:
return has_file_allowed_extension(filename, IMG_EXTENSIONS)
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset structured as follows:

    .. code::

        directory/
        ├── class_x
        │   ├── xxx.ext
        │   ├── xxy.ext
        │   └── ...
        │       └── xxz.ext
        └── class_y
            ├── 123.ext
            ├── nsdf3.ext
            └── ...
            └── asd932_.ext

    Only the immediate sub-directories of ``directory`` are treated as classes;
    plain files at the top level are ignored.

    Args:
        directory (str): Root directory path.

    Raises:
        FileNotFoundError: If ``directory`` has no class folders.

    Returns:
        (Tuple[List[str], Dict[str, int]]): List of all classes (sorted by name) and
        dictionary mapping each class to its position in that list.
    """
    # Use scandir as a context manager so the directory handle is closed
    # deterministically, even if inspecting an entry raises part-way through.
    with os.scandir(directory) as entries:
        classes = sorted(entry.name for entry in entries if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx
def make_dataset(
    directory: str,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    Args:
        directory (str): root dataset directory
        class_to_idx (Optional[Dict[str, int]]): Dictionary mapping class name to class index. If omitted, is generated
            by :func:`find_classes`.
        extensions (optional): A list of allowed extensions.
            Either extensions or is_valid_file should be passed. Defaults to None.
        is_valid_file (optional): A function that takes path of a file
            and decides whether it is a valid sample. Both extensions and
            is_valid_file should not be passed. Defaults to None.

    Raises:
        ValueError: In case ``class_to_idx`` is empty.
        ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
        FileNotFoundError: In case no valid file was found for any class.

    Returns:
        List[Tuple[str, int]]: samples of a form (path_to_sample, class)
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))

    # Exactly one of the two filters was supplied, so this is a callable from here on.
    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    # Classes for which at least one valid file has been collected.
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        # NOTE(review): the directory traversal between computing ``target_dir`` and
        # the ``is_valid_file(path)`` check is elided in the reviewed diff hunk; the
        # walk below is a reconstruction -- confirm against the full file.
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    item = path, class_index
                    instances.append(item)
                    available_classes.add(target_class)

    # Requested classes minus the ones that yielded a sample. The operands must be
    # in this order: ``available_classes`` is always a subset of the requested
    # classes, so the reverse difference would be empty and never raise.
    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances
......@@ -125,11 +184,6 @@ class DatasetFolder(VisionDataset):
target_transform=target_transform)
classes, class_to_idx = self._find_classes(self.root)
samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
if len(samples) == 0:
msg = "Found 0 files in subfolders of: {}\n".format(self.root)
if extensions is not None:
msg += "Supported extensions are: {}".format(",".join(extensions))
raise RuntimeError(msg)
self.loader = loader
self.extensions = extensions
......@@ -148,23 +202,9 @@ class DatasetFolder(VisionDataset):
) -> List[Tuple[str, int]]:
return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)
def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
"""
Finds the class folders in a dataset.
Args:
dir (string): Root directory path.
Returns:
tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
Ensures:
No class is a subdirectory of another.
"""
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
classes.sort()
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
@staticmethod
def _find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]:
    """Return the class names and class-to-index mapping found under ``dir``.

    Thin wrapper that delegates to the module-level :func:`find_classes`;
    see that function for the return value and the error it raises.
    NOTE(review): presumably kept so existing ``self._find_classes(...)``
    call sites keep working -- confirm against callers.
    """
    return find_classes(dir)
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
......
......@@ -2,7 +2,7 @@ import glob
import os
from .utils import list_dir
from .folder import make_dataset
from .folder import find_classes, make_dataset
from .video_utils import VideoClips
from .vision import VisionDataset
......@@ -62,8 +62,7 @@ class HMDB51(VisionDataset):
raise ValueError("fold should be between 1 and 3, got {}".format(fold))
extensions = ('avi',)
classes = sorted(list_dir(root))
class_to_idx = {class_: i for (i, class_) in enumerate(classes)}
self.classes, class_to_idx = find_classes(self.root)
self.samples = make_dataset(
self.root,
class_to_idx,
......@@ -89,7 +88,6 @@ class HMDB51(VisionDataset):
self.full_video_clips = video_clips
self.fold = fold
self.train = train
self.classes = classes
self.indices = self._select_fold(video_paths, annotation_path, fold, train)
self.video_clips = video_clips.subset(self.indices)
self.transform = transform
......
from .utils import list_dir
from .folder import make_dataset
from .folder import find_classes, make_dataset
from .video_utils import VideoClips
from .vision import VisionDataset
......@@ -56,10 +56,8 @@ class Kinetics400(VisionDataset):
_video_min_dimension=0, _audio_samples=0, _audio_channels=0):
super(Kinetics400, self).__init__(root)
classes = list(sorted(list_dir(root)))
class_to_idx = {classes[i]: i for i in range(len(classes))}
self.classes, class_to_idx = find_classes(self.root)
self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
self.classes = classes
video_list = [x[0] for x in self.samples]
self.video_clips = VideoClips(
video_list,
......
import os
from .utils import list_dir
from .folder import make_dataset
from .folder import find_classes, make_dataset
from .video_utils import VideoClips
from .vision import VisionDataset
......@@ -55,10 +55,8 @@ class UCF101(VisionDataset):
self.fold = fold
self.train = train
classes = list(sorted(list_dir(root)))
class_to_idx = {classes[i]: i for i in range(len(classes))}
self.classes, class_to_idx = find_classes(self.root)
self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
self.classes = classes
video_list = [x[0] for x in self.samples]
video_clips = VideoClips(
video_list,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment