Unverified Commit 6512146e authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

allow single extension as str in make_dataset (#5229)

* allow single extension as str in make_dataset

* remove test class

* remove regex

* revert collection to tuple

* cleanup
parent c27bed45
import contextlib import contextlib
import gzip import gzip
import os import os
import pathlib
import re
import tarfile import tarfile
import zipfile import zipfile
import pytest import pytest
import torchvision.datasets.utils as utils import torchvision.datasets.utils as utils
from torch._utils_internal import get_file_path_2 from torch._utils_internal import get_file_path_2
from torchvision.datasets.folder import make_dataset
from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS
TEST_FILE = get_file_path_2( TEST_FILE = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg" os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
) )
...@@ -214,5 +216,29 @@ class TestDatasetsUtils: ...@@ -214,5 +216,29 @@ class TestDatasetsUtils:
pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg") pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg")
@pytest.mark.parametrize(
("kwargs", "expected_error_msg"),
[
(dict(is_valid_file=lambda path: pathlib.Path(path).suffix in {".png", ".jpeg"}), "classes c"),
(dict(extensions=".png"), re.escape("classes b, c. Supported extensions are: .png")),
(dict(extensions=(".png", ".jpeg")), re.escape("classes c. Supported extensions are: .png, .jpeg")),
],
)
def test_make_dataset_no_valid_files(tmpdir, kwargs, expected_error_msg):
tmpdir = pathlib.Path(tmpdir)
(tmpdir / "a").mkdir()
(tmpdir / "a" / "a.png").touch()
(tmpdir / "b").mkdir()
(tmpdir / "b" / "b.jpeg").touch()
(tmpdir / "c").mkdir()
(tmpdir / "c" / "c.unknown").touch()
with pytest.raises(FileNotFoundError, match=expected_error_msg):
make_dataset(str(tmpdir), **kwargs)
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
import os import os
import os.path import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple from typing import Any, Callable, cast, Dict, List, Optional, Tuple
from typing import Union
from PIL import Image from PIL import Image
from .vision import VisionDataset from .vision import VisionDataset
def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool: def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool:
"""Checks if a file is an allowed extension. """Checks if a file is an allowed extension.
Args: Args:
...@@ -17,7 +18,7 @@ def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bo ...@@ -17,7 +18,7 @@ def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bo
Returns: Returns:
bool: True if the filename ends with one of given extensions bool: True if the filename ends with one of given extensions
""" """
return filename.lower().endswith(extensions) return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions))
def is_image_file(filename: str) -> bool: def is_image_file(filename: str) -> bool:
...@@ -48,7 +49,7 @@ def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]: ...@@ -48,7 +49,7 @@ def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
def make_dataset( def make_dataset(
directory: str, directory: str,
class_to_idx: Optional[Dict[str, int]] = None, class_to_idx: Optional[Dict[str, int]] = None,
extensions: Optional[Tuple[str, ...]] = None, extensions: Optional[Union[str, Tuple[str, ...]]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None, is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]: ) -> List[Tuple[str, int]]:
"""Generates a list of samples of a form (path_to_sample, class). """Generates a list of samples of a form (path_to_sample, class).
...@@ -73,7 +74,7 @@ def make_dataset( ...@@ -73,7 +74,7 @@ def make_dataset(
if extensions is not None: if extensions is not None:
def is_valid_file(x: str) -> bool: def is_valid_file(x: str) -> bool:
return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions)) return has_file_allowed_extension(x, extensions) # type: ignore[arg-type]
is_valid_file = cast(Callable[[str], bool], is_valid_file) is_valid_file = cast(Callable[[str], bool], is_valid_file)
...@@ -98,7 +99,7 @@ def make_dataset( ...@@ -98,7 +99,7 @@ def make_dataset(
if empty_classes: if empty_classes:
msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. " msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
if extensions is not None: if extensions is not None:
msg += f"Supported extensions are: {', '.join(extensions)}" msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
raise FileNotFoundError(msg) raise FileNotFoundError(msg)
return instances return instances
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment