Unverified Commit b94004a6, authored by Nicolas Hug, committed by GitHub

Minor changes to prototype datasets (#5282)

* Some Qs

* Some modifications

* don't need _loader in __init__

* list_names -> list_datasets

* Update torchvision/prototype/datasets/utils/_resource.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Remove unused import

* fix tests

* Some missing renames
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent f03ca0f9
@@ -18,7 +18,7 @@ def test_home(mocker, tmp_path):
 def test_coverage():
-    untested_datasets = set(datasets.list()) - DATASET_MOCKS.keys()
+    untested_datasets = set(datasets.list_datasets()) - DATASET_MOCKS.keys()
     if untested_datasets:
         raise AssertionError(
             f"The dataset(s) {sequence_to_str(sorted(untested_datasets), separate_last='and ')} "
@@ -11,5 +11,5 @@ from . import decoder, utils
 from ._home import home
 # Load this last, since some parts depend on the above being loaded first
-from ._api import register, _list as list, info, load, find  # usort: skip
+from ._api import register, list_datasets, info, load, find  # usort: skip
 from ._folder import from_data_folder, from_image_folder
@@ -23,8 +23,7 @@ for name, obj in _builtin.__dict__.items():
     register(obj())
-# This is exposed as 'list', but we avoid that here to not shadow the built-in 'list'
-def _list() -> List[str]:
+def list_datasets() -> List[str]:
     return sorted(DATASETS.keys())
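After the rename, callers enumerate the registry through list_datasets() instead of the shadow-prone list(). A minimal usage sketch; the dataset name "mnist" is only an illustrative example, not part of this diff:

    from torchvision.prototype import datasets

    # All registered builtin dataset names, sorted alphabetically.
    for name in datasets.list_datasets():
        print(name)

    # The returned names are the keys accepted by info()/load().
    print(datasets.info("mnist"))  # "mnist" is a hypothetical example name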
@@ -39,7 +38,7 @@ def find(name: str) -> Dataset:
                 word=name,
                 possibilities=DATASETS.keys(),
                 alternative_hint=lambda _: (
-                    "You can use torchvision.datasets.list() to get a list of all available datasets."
+                    "You can use torchvision.datasets.list_datasets() to get a list of all available datasets."
                 ),
             )
         ) from error
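For context, the find() helper containing this hunk raises a "did you mean" style error for unknown dataset names. A standalone sketch of the same idea built on difflib.get_close_matches; this is an illustration, not the actual helper:

    import difflib

    DATASETS = {"mnist": object(), "cifar10": object(), "imagenet": object()}

    def find(name: str):
        try:
            return DATASETS[name]
        except KeyError as error:
            close = difflib.get_close_matches(name, DATASETS.keys(), n=1)
            hint = (
                f"Did you mean '{close[0]}'?"
                if close
                else "You can use torchvision.datasets.list_datasets() to get a list of all available datasets."
            )
            raise ValueError(f"Unknown dataset '{name}'. {hint}") from error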
@@ -49,7 +49,7 @@ def parse_args(argv=None):
     args = parser.parse_args(argv or sys.argv[1:])
     if not args.names:
-        args.names = datasets.list()
+        args.names = datasets.list_datasets()
     return args
@@ -24,6 +24,8 @@ class DatasetType(enum.Enum):
 class DatasetConfig(FrozenBunch):
+    # This needs to be Frozen because we often pass configs as partial(func, config=config)
+    # and partial() requires the parameters to be hashable.
     pass
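The new comment is the key point of this hunk: configs travel through the pipeline as functools.partial(func, config=config), so the config object must be hashable, which mutable containers are not. A minimal sketch of that constraint, with a frozen dataclass standing in for FrozenBunch (a hypothetical stand-in, not the real class):

    import functools
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class Config:  # stand-in for FrozenBunch; frozen=True makes instances hashable
        split: str = "train"

    def process(sample, *, config):
        return sample, config.split

    cfg = Config(split="test")
    fn = functools.partial(process, config=cfg)

    hash(cfg)       # works; a plain dict here would raise TypeError
    print(fn("x"))  # ('x', 'test')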
@@ -39,7 +39,6 @@ __all__ = [
     "BUILTIN_DIR",
     "read_mat",
     "image_buffer_from_array",
-    "SequenceIterator",
     "MappingIterator",
     "Enumerator",
     "getitem",
@@ -80,15 +79,6 @@ def image_buffer_from_array(array: np.ndarray, *, format: str = "png") -> io.BytesIO:
     return buffer
 
-class SequenceIterator(IterDataPipe[D]):
-    def __init__(self, datapipe: IterDataPipe[Sequence[D]]):
-        self.datapipe = datapipe
-
-    def __iter__(self) -> Iterator[D]:
-        for sequence in self.datapipe:
-            yield from iter(sequence)
-
 class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]):
     def __init__(self, datapipe: IterDataPipe[Dict[K, D]], *, drop_key: bool = False) -> None:
         self.datapipe = datapipe
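SequenceIterator only flattened a datapipe of sequences into a stream of their elements, and nothing used it anymore. For reference, the same behavior as a plain generator (a sketch mirroring the removed class, not code from this PR):

    from typing import Iterable, Iterator, Sequence, TypeVar

    D = TypeVar("D")

    def flatten(sequences: Iterable[Sequence[D]]) -> Iterator[D]:
        # What SequenceIterator did: yield every element of every
        # inner sequence, preserving order.
        for sequence in sequences:
            yield from sequence

    print(list(flatten([[1, 2], [3], [4, 5]])))  # [1, 2, 3, 4, 5]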
@@ -2,7 +2,6 @@ import abc
 import hashlib
 import itertools
 import pathlib
-import warnings
 from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn
 from urllib.parse import urlparse
@@ -32,23 +31,17 @@ class OnlineResource(abc.ABC):
         sha256: Optional[str] = None,
         decompress: bool = False,
         extract: bool = False,
-        preprocess: Optional[Callable[[pathlib.Path], pathlib.Path]] = None,
-        loader: Optional[Callable[[pathlib.Path], IterDataPipe[Tuple[str, IO]]]] = None,
     ) -> None:
         self.file_name = file_name
         self.sha256 = sha256
-        if preprocess and (decompress or extract):
-            warnings.warn("The parameters 'decompress' and 'extract' are ignored when 'preprocess' is passed.")
-        elif extract:
-            preprocess = self._extract
-        elif decompress:
-            preprocess = self._decompress
-        self._preprocess = preprocess
-        if loader is None:
-            loader = self._default_loader
-        self._loader = loader
+        self._preprocess: Optional[Callable[[pathlib.Path], pathlib.Path]]
+        if extract:
+            self._preprocess = self._extract
+        elif decompress:
+            self._preprocess = self._decompress
+        else:
+            self._preprocess = None
 
     @staticmethod
     def _extract(file: pathlib.Path) -> pathlib.Path:
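With the preprocess and loader parameters gone, the preprocessing step is selected purely by the decompress/extract flags, leaving exactly one code path per flag combination. A hedged usage sketch with HttpResource, assuming it forwards keyword arguments to OnlineResource.__init__ as the prototype did at the time; URL and checksum are placeholders:

    from torchvision.prototype.datasets.utils import HttpResource

    # extract=True makes _preprocess resolve to self._extract, so the
    # downloaded archive is unpacked before loading.
    resource = HttpResource(
        "https://example.com/data.tar.gz",  # placeholder URL
        sha256="0" * 64,                    # placeholder checksum
        extract=True,
    )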
@@ -60,7 +53,7 @@ class OnlineResource(abc.ABC):
     def _decompress(file: pathlib.Path) -> pathlib.Path:
         return pathlib.Path(_decompress(str(file), remove_finished=True))
 
-    def _default_loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
+    def _loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
         if path.is_dir():
             return FileOpener(FileLister(str(path), recursive=True), mode="rb")
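The renamed _loader composes two datapipe primitives: FileLister enumerates files under a directory and FileOpener turns each into a (path, binary stream) pair. A standalone sketch of that pattern, assuming the torchdata import path and a placeholder directory:

    from torchdata.datapipes.iter import FileLister, FileOpener

    dp = FileOpener(FileLister("path/to/extracted", recursive=True), mode="rb")
    for path, stream in dp:
        # Each item pairs a file path with an opened binary handle.
        print(path)
        stream.close()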
@@ -101,7 +94,7 @@ class OnlineResource(abc.ABC):
         path_candidates = {self.download(root, skip_integrity_check=skip_integrity_check)}
         # If the only thing we find is the raw file, we use it and optionally perform some preprocessing steps.
         if path_candidates == {path}:
-            if self._preprocess:
+            if self._preprocess is not None:
                 path = self._preprocess(path)
         # Otherwise we use the path with the fewest suffixes. This gives us the extracted > decompressed > raw priority
         # that we want.
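The "fewest suffixes" comment describes the selection rule applied to the candidate paths elided from this hunk: an extracted directory foo has fewer suffixes than the decompressed foo.tar, which has fewer than the raw foo.tar.gz. A sketch of that rule, not the literal implementation:

    import pathlib

    candidates = {
        pathlib.Path("foo"),
        pathlib.Path("foo.tar"),
        pathlib.Path("foo.tar.gz"),
    }

    # Fewer suffixes means "more processed", so it wins.
    best = min(candidates, key=lambda path: len(path.suffixes))
    print(best)  # foo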