Unverified Commit 9563e3e3 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Add allow_empty parameter to ImageFolder and related utils (#8311)

parent e00f4e66
...@@ -1620,6 +1620,10 @@ class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1620,6 +1620,10 @@ class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
num_examples_total += num_examples num_examples_total += num_examples
classes.append(cls) classes.append(cls)
if config.pop("make_empty_class", False):
os.makedirs(pathlib.Path(tmpdir) / "empty_class")
classes.append("empty_class")
return dict(num_examples=num_examples_total, classes=classes) return dict(num_examples=num_examples_total, classes=classes)
def _file_name_fn(self, cls, ext, idx): def _file_name_fn(self, cls, ext, idx):
...@@ -1644,6 +1648,23 @@ class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1644,6 +1648,23 @@ class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
assert len(dataset.classes) == len(info["classes"]) assert len(dataset.classes) == len(info["classes"])
assert all([a == b for a, b in zip(dataset.classes, info["classes"])]) assert all([a == b for a, b in zip(dataset.classes, info["classes"])])
def test_allow_empty(self):
config = {
"extensions": self._EXTENSIONS,
"make_empty_class": True,
}
config["allow_empty"] = True
with self.create_dataset(config) as (dataset, info):
assert "empty_class" in dataset.classes
assert len(dataset.classes) == len(info["classes"])
assert all([a == b for a, b in zip(dataset.classes, info["classes"])])
config["allow_empty"] = False
with pytest.raises(FileNotFoundError, match="Found no valid file"):
with self.create_dataset(config) as (dataset, info):
pass
class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase): class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.ImageFolder DATASET_CLASS = datasets.ImageFolder
......
...@@ -50,6 +50,7 @@ def make_dataset( ...@@ -50,6 +50,7 @@ def make_dataset(
class_to_idx: Optional[Dict[str, int]] = None, class_to_idx: Optional[Dict[str, int]] = None,
extensions: Optional[Union[str, Tuple[str, ...]]] = None, extensions: Optional[Union[str, Tuple[str, ...]]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None, is_valid_file: Optional[Callable[[str], bool]] = None,
allow_empty: bool = False,
) -> List[Tuple[str, int]]: ) -> List[Tuple[str, int]]:
"""Generates a list of samples of a form (path_to_sample, class). """Generates a list of samples of a form (path_to_sample, class).
...@@ -95,7 +96,7 @@ def make_dataset( ...@@ -95,7 +96,7 @@ def make_dataset(
available_classes.add(target_class) available_classes.add(target_class)
empty_classes = set(class_to_idx.keys()) - available_classes empty_classes = set(class_to_idx.keys()) - available_classes
if empty_classes: if empty_classes and not allow_empty:
msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. " msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
if extensions is not None: if extensions is not None:
msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}" msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
...@@ -123,6 +124,8 @@ class DatasetFolder(VisionDataset): ...@@ -123,6 +124,8 @@ class DatasetFolder(VisionDataset):
is_valid_file (callable, optional): A function that takes path of a file is_valid_file (callable, optional): A function that takes path of a file
and check if the file is a valid file (used to check of corrupt files) and check if the file is a valid file (used to check of corrupt files)
both extensions and is_valid_file should not be passed. both extensions and is_valid_file should not be passed.
allow_empty(bool, optional): If True, empty folders are considered to be valid classes.
An error is raised on empty folders if False (default).
Attributes: Attributes:
classes (list): List of the class names sorted alphabetically. classes (list): List of the class names sorted alphabetically.
...@@ -139,10 +142,17 @@ class DatasetFolder(VisionDataset): ...@@ -139,10 +142,17 @@ class DatasetFolder(VisionDataset):
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
is_valid_file: Optional[Callable[[str], bool]] = None, is_valid_file: Optional[Callable[[str], bool]] = None,
allow_empty: bool = False,
) -> None: ) -> None:
super().__init__(root, transform=transform, target_transform=target_transform) super().__init__(root, transform=transform, target_transform=target_transform)
classes, class_to_idx = self.find_classes(self.root) classes, class_to_idx = self.find_classes(self.root)
samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file) samples = self.make_dataset(
self.root,
class_to_idx=class_to_idx,
extensions=extensions,
is_valid_file=is_valid_file,
allow_empty=allow_empty,
)
self.loader = loader self.loader = loader
self.extensions = extensions self.extensions = extensions
...@@ -158,6 +168,7 @@ class DatasetFolder(VisionDataset): ...@@ -158,6 +168,7 @@ class DatasetFolder(VisionDataset):
class_to_idx: Dict[str, int], class_to_idx: Dict[str, int],
extensions: Optional[Tuple[str, ...]] = None, extensions: Optional[Tuple[str, ...]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None, is_valid_file: Optional[Callable[[str], bool]] = None,
allow_empty: bool = False,
) -> List[Tuple[str, int]]: ) -> List[Tuple[str, int]]:
"""Generates a list of samples of a form (path_to_sample, class). """Generates a list of samples of a form (path_to_sample, class).
...@@ -172,6 +183,8 @@ class DatasetFolder(VisionDataset): ...@@ -172,6 +183,8 @@ class DatasetFolder(VisionDataset):
and checks if the file is a valid file and checks if the file is a valid file
(used to check of corrupt files) both extensions and (used to check of corrupt files) both extensions and
is_valid_file should not be passed. Defaults to None. is_valid_file should not be passed. Defaults to None.
allow_empty(bool, optional): If True, empty folders are considered to be valid classes.
An error is raised on empty folders if False (default).
Raises: Raises:
ValueError: In case ``class_to_idx`` is empty. ValueError: In case ``class_to_idx`` is empty.
...@@ -186,7 +199,9 @@ class DatasetFolder(VisionDataset): ...@@ -186,7 +199,9 @@ class DatasetFolder(VisionDataset):
# find_classes() function, instead of using that of the find_classes() method, which # find_classes() function, instead of using that of the find_classes() method, which
# is potentially overridden and thus could have a different logic. # is potentially overridden and thus could have a different logic.
raise ValueError("The class_to_idx parameter cannot be None.") raise ValueError("The class_to_idx parameter cannot be None.")
return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file) return make_dataset(
directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file, allow_empty=allow_empty
)
def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]: def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
"""Find the class folders in a dataset structured as follows:: """Find the class folders in a dataset structured as follows::
...@@ -291,6 +306,8 @@ class ImageFolder(DatasetFolder): ...@@ -291,6 +306,8 @@ class ImageFolder(DatasetFolder):
loader (callable, optional): A function to load an image given its path. loader (callable, optional): A function to load an image given its path.
is_valid_file (callable, optional): A function that takes path of an Image file is_valid_file (callable, optional): A function that takes path of an Image file
and check if the file is a valid file (used to check of corrupt files) and check if the file is a valid file (used to check of corrupt files)
allow_empty(bool, optional): If True, empty folders are considered to be valid classes.
An error is raised on empty folders if False (default).
Attributes: Attributes:
classes (list): List of the class names sorted alphabetically. classes (list): List of the class names sorted alphabetically.
...@@ -305,6 +322,7 @@ class ImageFolder(DatasetFolder): ...@@ -305,6 +322,7 @@ class ImageFolder(DatasetFolder):
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
loader: Callable[[str], Any] = default_loader, loader: Callable[[str], Any] = default_loader,
is_valid_file: Optional[Callable[[str], bool]] = None, is_valid_file: Optional[Callable[[str], bool]] = None,
allow_empty: bool = False,
): ):
super().__init__( super().__init__(
root, root,
...@@ -313,5 +331,6 @@ class ImageFolder(DatasetFolder): ...@@ -313,5 +331,6 @@ class ImageFolder(DatasetFolder):
transform=transform, transform=transform,
target_transform=target_transform, target_transform=target_transform,
is_valid_file=is_valid_file, is_valid_file=is_valid_file,
allow_empty=allow_empty,
) )
self.imgs = self.samples self.imgs = self.samples
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment