Unverified Commit 6512146e authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

allow single extension as str in make_dataset (#5229)

* allow single extension as str in make_dataset

* remove test class

* remove regex

* revert collection to tuple

* cleanup
parent c27bed45
import contextlib
import gzip
import os
import pathlib
import re
import tarfile
import zipfile
import pytest
import torchvision.datasets.utils as utils
from torch._utils_internal import get_file_path_2
from torchvision.datasets.folder import make_dataset
from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS
TEST_FILE = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
)
......@@ -214,5 +216,29 @@ class TestDatasetsUtils:
pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg")
@pytest.mark.parametrize(
("kwargs", "expected_error_msg"),
[
(dict(is_valid_file=lambda path: pathlib.Path(path).suffix in {".png", ".jpeg"}), "classes c"),
(dict(extensions=".png"), re.escape("classes b, c. Supported extensions are: .png")),
(dict(extensions=(".png", ".jpeg")), re.escape("classes c. Supported extensions are: .png, .jpeg")),
],
)
def test_make_dataset_no_valid_files(tmpdir, kwargs, expected_error_msg):
tmpdir = pathlib.Path(tmpdir)
(tmpdir / "a").mkdir()
(tmpdir / "a" / "a.png").touch()
(tmpdir / "b").mkdir()
(tmpdir / "b" / "b.jpeg").touch()
(tmpdir / "c").mkdir()
(tmpdir / "c" / "c.unknown").touch()
with pytest.raises(FileNotFoundError, match=expected_error_msg):
make_dataset(str(tmpdir), **kwargs)
if __name__ == "__main__":
pytest.main([__file__])
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
from typing import Union
from PIL import Image
from .vision import VisionDataset
def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool:
"""Checks if a file is an allowed extension.
Args:
......@@ -17,7 +18,7 @@ def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bo
Returns:
bool: True if the filename ends with one of given extensions
"""
return filename.lower().endswith(extensions)
return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions))
def is_image_file(filename: str) -> bool:
......@@ -48,7 +49,7 @@ def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
def make_dataset(
directory: str,
class_to_idx: Optional[Dict[str, int]] = None,
extensions: Optional[Tuple[str, ...]] = None,
extensions: Optional[Union[str, Tuple[str, ...]]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
"""Generates a list of samples of a form (path_to_sample, class).
......@@ -73,7 +74,7 @@ def make_dataset(
if extensions is not None:
def is_valid_file(x: str) -> bool:
return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
return has_file_allowed_extension(x, extensions) # type: ignore[arg-type]
is_valid_file = cast(Callable[[str], bool], is_valid_file)
......@@ -98,7 +99,7 @@ def make_dataset(
if empty_classes:
msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
if extensions is not None:
msg += f"Supported extensions are: {', '.join(extensions)}"
msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
raise FileNotFoundError(msg)
return instances
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment