Unverified Commit a068602e authored by Philip Meier, committed by GitHub

cleanup prototype datasets (#4471)



* cleanup image folder

* make shuffling mandatory

* rename parameter in home() function

* don't shadow builtin list

* make categories optional in dataset info

* use pseudo-infinite buffer size for shuffler
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent 932ca5a3
@@ -2,5 +2,5 @@ from ._home import home
 from . import decoder, utils
 
 # Load this last, since some parts depend on the above being loaded first
-from ._api import register, list, info, load
+from ._api import register, _list as list, info, load
 from ._folder import from_data_folder, from_image_folder
@@ -17,7 +17,8 @@ def register(dataset: Dataset) -> None:
     DATASETS[dataset.name] = dataset
 
 
-def list() -> List[str]:
+# This is exposed as 'list', but we avoid that here to not shadow the built-in 'list'
+def _list() -> List[str]:
     return sorted(DATASETS.keys())
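
The function is defined under the private name `_list` so the builtin `list` stays usable inside `_api.py`, while the package `__init__` re-exports it under the public name (`_list as list` above). The pattern in miniature, with a hypothetical registry:

# _api.py (sketch): a private name avoids shadowing the builtin in this module
DATASETS = {"mnist": None, "cifar10": None}  # hypothetical registry contents

def _list():
    # the builtin is still available here, e.g. list(DATASETS.items())
    return sorted(DATASETS.keys())

# __init__.py (sketch): the public name only exists at the package level
# from ._api import _list as list
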
@@ -45,7 +46,6 @@ def info(name: str) -> DatasetInfo:
 def load(
     name: str,
     *,
-    shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]] = None,
     decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil,
     split: str = "train",
     **options: Any,
@@ -55,4 +55,4 @@ def load(
     config = dataset.info.make_config(split=split, **options)
     root = home() / name
-    return dataset.to_datapipe(root, config=config, shuffler=shuffler, decoder=decoder)
+    return dataset.to_datapipe(root, config=config, decoder=decoder)
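
With the `shuffler` argument gone, callers no longer opt in to shuffling; `load()` only takes the decoder, the split, and extra config options. A hedged usage sketch; the dataset name is hypothetical and must have been registered first:

from torchvision.prototype import datasets

# "caltech256" is an assumption; any registered name returned by datasets.list() works
dp = datasets.load("caltech256", split="train")
for sample in dp:  # per the return type, each sample is a Dict[str, Any]
    ...
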
@@ -10,13 +10,11 @@ from torch.utils.data import IterDataPipe
 from torch.utils.data.datapipes.iter import FileLister, FileLoader, Mapper, Shuffler, Filter
 from torchvision.prototype.datasets.decoder import pil
+from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
 
 __all__ = ["from_data_folder", "from_image_folder"]
 
-# pseudo-infinite buffer size until a true infinite buffer is supported
-INFINITE = 1_000_000_000
-
 
 def _is_not_top_level_file(path: str, *, root: pathlib.Path) -> bool:
     rel_path = pathlib.Path(path).relative_to(root)
@@ -45,7 +43,6 @@ def _collate_and_decode_data(
 def from_data_folder(
     root: Union[str, pathlib.Path],
     *,
-    shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]] = lambda dp: Shuffler(dp, buffer_size=INFINITE),
     decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
     valid_extensions: Optional[Collection[str]] = None,
     recursive: bool = True,
@@ -55,8 +52,7 @@ def from_data_folder(
     masks: Union[List[str], str] = [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else ""
     dp: IterDataPipe = FileLister(str(root), recursive=recursive, masks=masks)
     dp = Filter(dp, _is_not_top_level_file, fn_kwargs=dict(root=root))
-    if shuffler:
-        dp = shuffler(dp)
+    dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
     dp = FileLoader(dp)
     return (
         Mapper(dp, _collate_and_decode_data, fn_kwargs=dict(root=root, categories=categories, decoder=decoder)),
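
With the `shuffler` parameter removed, every pipeline built here shuffles unconditionally. A hedged usage sketch; the second element of the returned tuple is truncated in the diff above (likely the category list), so the unpacking is an assumption:

from torchvision.prototype import datasets

# assumed layout: root/<category>/<file>, one sub-folder per class
dp, categories = datasets.from_data_folder("path/to/root", valid_extensions=["jpg"])

for sample in dp:
    ...  # samples arrive pre-shuffled; the Shuffler is now always in the pipe
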
@@ -7,14 +7,14 @@ from torch.hub import _get_torch_home
 HOME = pathlib.Path(_get_torch_home()) / "datasets" / "vision"
 
 
-def home(home: Optional[Union[str, pathlib.Path]] = None) -> pathlib.Path:
+def home(root: Optional[Union[str, pathlib.Path]] = None) -> pathlib.Path:
     global HOME
-    if home is not None:
-        HOME = pathlib.Path(home).expanduser().resolve()
+    if root is not None:
+        HOME = pathlib.Path(root).expanduser().resolve()
         return HOME
 
-    home = os.getenv("TORCHVISION_DATASETS_HOME")
-    if home is not None:
-        return pathlib.Path(home)
+    root = os.getenv("TORCHVISION_DATASETS_HOME")
+    if root is not None:
+        return pathlib.Path(root)
 
     return HOME
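
The rename fixes the parameter shadowing the function's own name. Note the lookup order that falls out of the code: an explicit argument sets the global, but the environment variable is consulted on every argument-less call and therefore wins over a previously set default. A small sketch:

import os
from torchvision.prototype import datasets

datasets.home("~/data/vision")   # sets the global root and returns the resolved path
print(datasets.home())           # ~/data/vision, expanded and resolved

os.environ["TORCHVISION_DATASETS_HOME"] = "/mnt/vision"
print(datasets.home())           # /mnt/vision -- the env var takes precedence
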
 import io
 
-import numpy as np
 import PIL.Image
 import torch
+from torchvision.transforms.functional import pil_to_tensor
 
 __all__ = ["pil"]
 
 
 def pil(file: io.IOBase, mode: str = "RGB") -> torch.Tensor:
-    image = PIL.Image.open(file).convert(mode.upper())
-    return torch.from_numpy(np.array(image, copy=True)).permute((2, 0, 1))
+    return pil_to_tensor(PIL.Image.open(file).convert(mode.upper()))
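
`pil_to_tensor` performs the same conversion as the removed numpy round trip: it keeps the `uint8` dtype and reorders HWC to CHW without rescaling (unlike `to_tensor`, which also divides by 255). A quick check, assuming a local `image.jpg`:

import PIL.Image
from torchvision.transforms.functional import pil_to_tensor

with open("image.jpg", "rb") as file:  # hypothetical input file
    tensor = pil_to_tensor(PIL.Image.open(file).convert("RGB"))

print(tensor.dtype, tensor.shape)      # torch.uint8 torch.Size([3, H, W])
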
@@ -98,7 +98,7 @@ class DatasetInfo:
         self,
         name: str,
         *,
-        categories: Union[int, Sequence[str], str, pathlib.Path],
+        categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None,
         citation: Optional[str] = None,
         homepage: Optional[str] = None,
         license: Optional[str] = None,
@@ -106,7 +106,9 @@ class DatasetInfo:
     ) -> None:
         self.name = name.lower()
 
-        if isinstance(categories, int):
+        if categories is None:
+            categories = []
+        elif isinstance(categories, int):
             categories = [str(label) for label in range(categories)]
         elif isinstance(categories, (str, pathlib.Path)):
             with open(pathlib.Path(categories).expanduser().resolve(), "r") as fh:
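
`categories` now accepts `None` (a dataset without categories), an int (generates string labels), a sequence of names, or a path to a file of names. A standalone sketch of the coercion; the file-reading branch is truncated in the diff, so the one-name-per-line parsing is an assumption:

import pathlib
from typing import Optional, Sequence, Union

def normalize_categories(
    categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None,
):
    # mirrors the branches in DatasetInfo.__init__
    if categories is None:
        return []
    elif isinstance(categories, int):
        return [str(label) for label in range(categories)]
    elif isinstance(categories, (str, pathlib.Path)):
        with open(pathlib.Path(categories).expanduser().resolve(), "r") as fh:
            return [line.strip() for line in fh]  # assumed: one category per line
    return list(categories)

assert normalize_categories() == []
assert normalize_categories(3) == ["0", "1", "2"]
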
@@ -198,7 +200,6 @@ class Dataset(abc.ABC):
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]],
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         pass
@@ -208,7 +209,6 @@ class Dataset(abc.ABC):
         root: Union[str, pathlib.Path],
         *,
         config: Optional[DatasetConfig] = None,
-        shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]] = None,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
     ) -> IterDataPipe[Dict[str, Any]]:
         if not config:
@@ -217,4 +217,4 @@ class Dataset(abc.ABC):
         resource_dps = [
             resource.to_datapipe(root) for resource in self.resources(config)
         ]
-        return self._make_datapipe(resource_dps, config=config, shuffler=shuffler, decoder=decoder)
+        return self._make_datapipe(resource_dps, config=config, decoder=decoder)
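
Concrete datasets now implement the slimmer hook. A hypothetical skeleton (the other abstract members of `Dataset` are elided; only the signature comes from the diff):

class MyDataset(Dataset):
    def _make_datapipe(self, resource_dps, *, config, decoder):
        # shuffling is no longer injected from the outside; implementations
        # shuffle internally (see INFINITE_BUFFER_SIZE below) and apply the decoder
        dp = resource_dps[0]
        ...
        return dp
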
@@ -4,10 +4,14 @@ from typing import Collection, Sequence, Callable
 __all__ = [
+    "INFINITE_BUFFER_SIZE",
     "sequence_to_str",
     "add_suggestion",
 ]
 
+# pseudo-infinite until a true infinite buffer is supported by all datapipes
+INFINITE_BUFFER_SIZE = 1_000_000_000
+
 
 def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
     if len(seq) == 1:
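
A `Shuffler` only yields a uniform permutation of the whole stream when its buffer is at least as large as the stream, which is why the shared constant is a pseudo-infinite 1_000_000_000; the buffer fills lazily, so the large size does not preallocate memory. A self-contained demonstration; the toy datapipe is hypothetical:

from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.iter import Shuffler

INFINITE_BUFFER_SIZE = 1_000_000_000  # value from this commit

class RangeDataPipe(IterDataPipe):
    """Toy datapipe yielding 0 .. n-1, just for illustration."""

    def __init__(self, n: int) -> None:
        self.n = n

    def __iter__(self):
        yield from range(self.n)

dp = Shuffler(RangeDataPipe(10), buffer_size=INFINITE_BUFFER_SIZE)
print(list(dp))  # some permutation of 0..9
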