Unverified Commit 0818c682 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Improve error handling in make_dataset (#3496)

* factor out find_classes

* use find_classes in video datasets

* adapt old tests
parent 19ad0bbc
...@@ -111,10 +111,10 @@ class Tester(DatasetTestcase): ...@@ -111,10 +111,10 @@ class Tester(DatasetTestcase):
def test_imagefolder_empty(self): def test_imagefolder_empty(self):
with get_tmp_dir() as root: with get_tmp_dir() as root:
with self.assertRaises(RuntimeError): with self.assertRaises(FileNotFoundError):
torchvision.datasets.ImageFolder(root, loader=lambda x: x) torchvision.datasets.ImageFolder(root, loader=lambda x: x)
with self.assertRaises(RuntimeError): with self.assertRaises(FileNotFoundError):
torchvision.datasets.ImageFolder( torchvision.datasets.ImageFolder(
root, loader=lambda x: x, is_valid_file=lambda x: False root, loader=lambda x: x, is_valid_file=lambda x: False
) )
...@@ -1092,9 +1092,6 @@ class Kinetics400TestCase(datasets_utils.VideoDatasetTestCase): ...@@ -1092,9 +1092,6 @@ class Kinetics400TestCase(datasets_utils.VideoDatasetTestCase):
return num_videos_per_class * len(classes) return num_videos_per_class * len(classes)
def test_not_found_or_corrupted(self):
self.skipTest("Dataset currently does not handle the case of no found videos.")
class HMDB51TestCase(datasets_utils.VideoDatasetTestCase): class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
DATASET_CLASS = datasets.HMDB51 DATASET_CLASS = datasets.HMDB51
......
...@@ -32,9 +32,43 @@ def is_image_file(filename: str) -> bool: ...@@ -32,9 +32,43 @@ def is_image_file(filename: str) -> bool:
return has_file_allowed_extension(filename, IMG_EXTENSIONS) return has_file_allowed_extension(filename, IMG_EXTENSIONS)
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
"""Finds the class folders in a dataset structured as follows:
.. code::
directory/
├── class_x
│ ├── xxx.ext
│ ├── xxy.ext
│ └── ...
│ └── xxz.ext
└── class_y
├── 123.ext
├── nsdf3.ext
└── ...
└── asd932_.ext
Args:
directory (str): Root directory path.
Raises:
FileNotFoundError: If ``directory`` has no class folders.
Returns:
(Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
"""
classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
if not classes:
raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
def make_dataset( def make_dataset(
directory: str, directory: str,
class_to_idx: Dict[str, int], class_to_idx: Optional[Dict[str, int]] = None,
extensions: Optional[Tuple[str, ...]] = None, extensions: Optional[Tuple[str, ...]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None, is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]: ) -> List[Tuple[str, int]]:
...@@ -42,7 +76,8 @@ def make_dataset( ...@@ -42,7 +76,8 @@ def make_dataset(
Args: Args:
directory (str): root dataset directory directory (str): root dataset directory
class_to_idx (Dict[str, int]): dictionary mapping class name to class index class_to_idx (Optional[Dict[str, int]]): Dictionary mapping class name to class index. If omitted, is generated
by :func:`find_classes`.
extensions (optional): A list of allowed extensions. extensions (optional): A list of allowed extensions.
Either extensions or is_valid_file should be passed. Defaults to None. Either extensions or is_valid_file should be passed. Defaults to None.
is_valid_file (optional): A function that takes path of a file is_valid_file (optional): A function that takes path of a file
...@@ -51,21 +86,34 @@ def make_dataset( ...@@ -51,21 +86,34 @@ def make_dataset(
is_valid_file should not be passed. Defaults to None. is_valid_file should not be passed. Defaults to None.
Raises: Raises:
ValueError: In case ``class_to_idx`` is empty.
ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None. ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
FileNotFoundError: In case no valid file was found for any class.
Returns: Returns:
List[Tuple[str, int]]: samples of a form (path_to_sample, class) List[Tuple[str, int]]: samples of a form (path_to_sample, class)
""" """
instances = []
directory = os.path.expanduser(directory) directory = os.path.expanduser(directory)
if class_to_idx is None:
_, class_to_idx = find_classes(directory)
elif not class_to_idx:
raise ValueError("'class_to_index' must have at least one entry to collect any samples.")
both_none = extensions is None and is_valid_file is None both_none = extensions is None and is_valid_file is None
both_something = extensions is not None and is_valid_file is not None both_something = extensions is not None and is_valid_file is not None
if both_none or both_something: if both_none or both_something:
raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
if extensions is not None: if extensions is not None:
def is_valid_file(x: str) -> bool: def is_valid_file(x: str) -> bool:
return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions)) return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
is_valid_file = cast(Callable[[str], bool], is_valid_file) is_valid_file = cast(Callable[[str], bool], is_valid_file)
instances = []
available_classes = set()
for target_class in sorted(class_to_idx.keys()): for target_class in sorted(class_to_idx.keys()):
class_index = class_to_idx[target_class] class_index = class_to_idx[target_class]
target_dir = os.path.join(directory, target_class) target_dir = os.path.join(directory, target_class)
...@@ -77,6 +125,17 @@ def make_dataset( ...@@ -77,6 +125,17 @@ def make_dataset(
if is_valid_file(path): if is_valid_file(path):
item = path, class_index item = path, class_index
instances.append(item) instances.append(item)
if target_class not in available_classes:
available_classes.add(target_class)
empty_classes = available_classes - set(class_to_idx.keys())
if empty_classes:
msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
if extensions is not None:
msg += f"Supported extensions are: {', '.join(extensions)}"
raise FileNotFoundError(msg)
return instances return instances
...@@ -125,11 +184,6 @@ class DatasetFolder(VisionDataset): ...@@ -125,11 +184,6 @@ class DatasetFolder(VisionDataset):
target_transform=target_transform) target_transform=target_transform)
classes, class_to_idx = self._find_classes(self.root) classes, class_to_idx = self._find_classes(self.root)
samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file) samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
if len(samples) == 0:
msg = "Found 0 files in subfolders of: {}\n".format(self.root)
if extensions is not None:
msg += "Supported extensions are: {}".format(",".join(extensions))
raise RuntimeError(msg)
self.loader = loader self.loader = loader
self.extensions = extensions self.extensions = extensions
...@@ -148,23 +202,9 @@ class DatasetFolder(VisionDataset): ...@@ -148,23 +202,9 @@ class DatasetFolder(VisionDataset):
) -> List[Tuple[str, int]]: ) -> List[Tuple[str, int]]:
return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file) return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)
def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]: @staticmethod
""" def _find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]:
Finds the class folders in a dataset. return find_classes(dir)
Args:
dir (string): Root directory path.
Returns:
tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
Ensures:
No class is a subdirectory of another.
"""
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
classes.sort()
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
def __getitem__(self, index: int) -> Tuple[Any, Any]: def __getitem__(self, index: int) -> Tuple[Any, Any]:
""" """
......
...@@ -2,7 +2,7 @@ import glob ...@@ -2,7 +2,7 @@ import glob
import os import os
from .utils import list_dir from .utils import list_dir
from .folder import make_dataset from .folder import find_classes, make_dataset
from .video_utils import VideoClips from .video_utils import VideoClips
from .vision import VisionDataset from .vision import VisionDataset
...@@ -62,8 +62,7 @@ class HMDB51(VisionDataset): ...@@ -62,8 +62,7 @@ class HMDB51(VisionDataset):
raise ValueError("fold should be between 1 and 3, got {}".format(fold)) raise ValueError("fold should be between 1 and 3, got {}".format(fold))
extensions = ('avi',) extensions = ('avi',)
classes = sorted(list_dir(root)) self.classes, class_to_idx = find_classes(self.root)
class_to_idx = {class_: i for (i, class_) in enumerate(classes)}
self.samples = make_dataset( self.samples = make_dataset(
self.root, self.root,
class_to_idx, class_to_idx,
...@@ -89,7 +88,6 @@ class HMDB51(VisionDataset): ...@@ -89,7 +88,6 @@ class HMDB51(VisionDataset):
self.full_video_clips = video_clips self.full_video_clips = video_clips
self.fold = fold self.fold = fold
self.train = train self.train = train
self.classes = classes
self.indices = self._select_fold(video_paths, annotation_path, fold, train) self.indices = self._select_fold(video_paths, annotation_path, fold, train)
self.video_clips = video_clips.subset(self.indices) self.video_clips = video_clips.subset(self.indices)
self.transform = transform self.transform = transform
......
from .utils import list_dir from .utils import list_dir
from .folder import make_dataset from .folder import find_classes, make_dataset
from .video_utils import VideoClips from .video_utils import VideoClips
from .vision import VisionDataset from .vision import VisionDataset
...@@ -56,10 +56,8 @@ class Kinetics400(VisionDataset): ...@@ -56,10 +56,8 @@ class Kinetics400(VisionDataset):
_video_min_dimension=0, _audio_samples=0, _audio_channels=0): _video_min_dimension=0, _audio_samples=0, _audio_channels=0):
super(Kinetics400, self).__init__(root) super(Kinetics400, self).__init__(root)
classes = list(sorted(list_dir(root))) self.classes, class_to_idx = find_classes(self.root)
class_to_idx = {classes[i]: i for i in range(len(classes))}
self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None) self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
self.classes = classes
video_list = [x[0] for x in self.samples] video_list = [x[0] for x in self.samples]
self.video_clips = VideoClips( self.video_clips = VideoClips(
video_list, video_list,
......
import os import os
from .utils import list_dir from .utils import list_dir
from .folder import make_dataset from .folder import find_classes, make_dataset
from .video_utils import VideoClips from .video_utils import VideoClips
from .vision import VisionDataset from .vision import VisionDataset
...@@ -55,10 +55,8 @@ class UCF101(VisionDataset): ...@@ -55,10 +55,8 @@ class UCF101(VisionDataset):
self.fold = fold self.fold = fold
self.train = train self.train = train
classes = list(sorted(list_dir(root))) self.classes, class_to_idx = find_classes(self.root)
class_to_idx = {classes[i]: i for i in range(len(classes))}
self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None) self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
self.classes = classes
video_list = [x[0] for x in self.samples] video_list = [x[0] for x in self.samples]
video_clips = VideoClips( video_clips = VideoClips(
video_list, video_list,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment