Unverified commit b94004a6 authored by Nicolas Hug, committed by GitHub

Minor changes to prototype datasets (#5282)



* Some Qs

* Some modifications

* don't need _loader in __init__

* list_names -> list_datasets

* Update torchvision/prototype/datasets/utils/_resource.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Remove unused import

* fix tests

* Some missing renames
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent f03ca0f9
```diff
@@ -18,7 +18,7 @@ def test_home(mocker, tmp_path):
 def test_coverage():
-    untested_datasets = set(datasets.list()) - DATASET_MOCKS.keys()
+    untested_datasets = set(datasets.list_datasets()) - DATASET_MOCKS.keys()
 
     if untested_datasets:
         raise AssertionError(
             f"The dataset(s) {sequence_to_str(sorted(untested_datasets), separate_last='and ')} "
```
```diff
@@ -11,5 +11,5 @@ from . import decoder, utils
 from ._home import home
 
 # Load this last, since some parts depend on the above being loaded first
-from ._api import register, _list as list, info, load, find  # usort: skip
+from ._api import register, list_datasets, info, load, find  # usort: skip
 from ._folder import from_data_folder, from_image_folder
```
```diff
@@ -23,8 +23,7 @@ for name, obj in _builtin.__dict__.items():
     register(obj())
 
 
-# This is exposed as 'list', but we avoid that here to not shadow the built-in 'list'
-def _list() -> List[str]:
+def list_datasets() -> List[str]:
     return sorted(DATASETS.keys())
```
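The rename removes the need for the `_list as list` import alias in `__init__.py`, which shadowed the built-in `list` for anyone star-importing the package. A minimal sketch of the new public helper (the registry contents here are made up):

```python
from typing import Dict, List

# Stand-in for the DATASETS registry in _api.py; the names are illustrative.
DATASETS: Dict[str, object] = {"imagenet": object(), "cifar10": object(), "mnist": object()}

def list_datasets() -> List[str]:
    # Sorted copy of the registered names, exactly as in the diff above.
    return sorted(DATASETS.keys())

print(list_datasets())  # ['cifar10', 'imagenet', 'mnist']
```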
```diff
@@ -39,7 +38,7 @@ def find(name: str) -> Dataset:
                 word=name,
                 possibilities=DATASETS.keys(),
                 alternative_hint=lambda _: (
-                    "You can use torchvision.datasets.list() to get a list of all available datasets."
+                    "You can use torchvision.datasets.list_datasets() to get a list of all available datasets."
                 ),
             )
         ) from error
```
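For context, `find` falls back to a close-match suggestion when the requested name is unknown. The helper torchvision uses is not shown in this diff, but the pattern is roughly the following `difflib`-based sketch (the registry and the error type are assumptions):

```python
import difflib
from typing import Any, Dict

DATASETS: Dict[str, Any] = {"imagenet": ..., "cifar10": ..., "mnist": ...}  # illustrative

def find(name: str) -> Any:
    try:
        return DATASETS[name]
    except KeyError as error:
        # Suggest the closest registered name, if any is similar enough.
        close = difflib.get_close_matches(name, DATASETS.keys(), n=1)
        hint = (
            f"Did you mean '{close[0]}'?"
            if close
            else "You can use torchvision.datasets.list_datasets() to get a list of all available datasets."
        )
        raise ValueError(f"Unknown dataset '{name}'. {hint}") from error

# find("imagnet") raises: ValueError: Unknown dataset 'imagnet'. Did you mean 'imagenet'?
```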
```diff
@@ -49,7 +49,7 @@ def parse_args(argv=None):
     args = parser.parse_args(argv or sys.argv[1:])
 
     if not args.names:
-        args.names = datasets.list()
+        args.names = datasets.list_datasets()
 
     return args
```
```diff
@@ -24,6 +24,8 @@ class DatasetType(enum.Enum):
 class DatasetConfig(FrozenBunch):
+    # This needs to be Frozen because we often pass configs as partial(func, config=config)
+    # and partial() requires the parameters to be hashable.
     pass
```
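The new comment is worth unpacking: a frozen (immutable) config is hashable, so a `partial` that closes over it can itself be hashed or cached. A sketch with a frozen dataclass standing in for `FrozenBunch` (the field names are made up):

```python
from dataclasses import dataclass
from functools import partial

# frozen=True makes instances immutable and therefore hashable.
@dataclass(frozen=True)
class Config:
    split: str = "train"
    year: str = "2021"

def make_datapipe(root: str, *, config: Config) -> None:
    ...  # build the pipeline for this config

config = Config(split="val")
factory = partial(make_datapipe, config=config)

hash(config)  # fine: frozen dataclasses are hashable
# With a plain dict in its place, hash(config) would raise
# TypeError: unhashable type: 'dict'.
```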
```diff
@@ -39,7 +39,6 @@ __all__ = [
     "BUILTIN_DIR",
     "read_mat",
     "image_buffer_from_array",
-    "SequenceIterator",
     "MappingIterator",
     "Enumerator",
     "getitem",
```
```diff
@@ -80,15 +79,6 @@ def image_buffer_from_array(array: np.ndarray, *, format: str = "png") -> io.BytesIO:
     return buffer
 
 
-class SequenceIterator(IterDataPipe[D]):
-    def __init__(self, datapipe: IterDataPipe[Sequence[D]]):
-        self.datapipe = datapipe
-
-    def __iter__(self) -> Iterator[D]:
-        for sequence in self.datapipe:
-            yield from iter(sequence)
-
-
 class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]):
     def __init__(self, datapipe: IterDataPipe[Dict[K, D]], *, drop_key: bool = False) -> None:
         self.datapipe = datapipe
```
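`SequenceIterator` can be dropped because its behavior amounts to a two-line generator; anywhere a pipe of sequences needs flattening, something like this sketch suffices:

```python
from typing import Iterable, Iterator, Sequence, TypeVar

D = TypeVar("D")

def flatten(datapipe: Iterable[Sequence[D]]) -> Iterator[D]:
    # Same behavior as the removed class: yield the items of every inner sequence.
    for sequence in datapipe:
        yield from sequence

assert list(flatten([[1, 2], [3]])) == [1, 2, 3]
```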
```diff
@@ -2,7 +2,6 @@ import abc
 import hashlib
 import itertools
 import pathlib
-import warnings
 from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn
 from urllib.parse import urlparse
```
```diff
@@ -32,23 +31,17 @@ class OnlineResource(abc.ABC):
         sha256: Optional[str] = None,
         decompress: bool = False,
         extract: bool = False,
-        preprocess: Optional[Callable[[pathlib.Path], pathlib.Path]] = None,
-        loader: Optional[Callable[[pathlib.Path], IterDataPipe[Tuple[str, IO]]]] = None,
     ) -> None:
         self.file_name = file_name
         self.sha256 = sha256
 
-        if preprocess and (decompress or extract):
-            warnings.warn("The parameters 'decompress' and 'extract' are ignored when 'preprocess' is passed.")
-        elif extract:
-            preprocess = self._extract
+        self._preprocess: Optional[Callable[[pathlib.Path], pathlib.Path]]
+        if extract:
+            self._preprocess = self._extract
         elif decompress:
-            preprocess = self._decompress
-        self._preprocess = preprocess
-
-        if loader is None:
-            loader = self._default_loader
-        self._loader = loader
+            self._preprocess = self._decompress
+        else:
+            self._preprocess = None
 
     @staticmethod
     def _extract(file: pathlib.Path) -> pathlib.Path:
```
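After this change a dataset declares only declarative flags on its resources: `extract` takes precedence over `decompress`, and the `preprocess`/`loader` hooks are gone from the signature. A hedged usage sketch (`HttpResource` is the concrete subclass in this module, but the exact call, URL, and checksum below are placeholders):

```python
from torchvision.prototype.datasets.utils import HttpResource

archive = HttpResource(
    "https://example.com/images.tar.gz",  # placeholder URL
    sha256="<expected checksum>",         # placeholder
    extract=True,  # unpack after download; wins over decompress=True if both are set
)
# Passing preprocess=... or loader=... is no longer possible (it would raise a
# TypeError); the loader is always the default one, renamed to
# OnlineResource._loader below.
```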
```diff
@@ -60,7 +53,7 @@ class OnlineResource(abc.ABC):
     def _decompress(file: pathlib.Path) -> pathlib.Path:
         return pathlib.Path(_decompress(str(file), remove_finished=True))
 
-    def _default_loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
+    def _loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
         if path.is_dir():
             return FileOpener(FileLister(str(path), recursive=True), mode="rb")
```
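For reference, the directory branch of `_loader` composes two standard datapipes: `FileLister` enumerates files recursively and `FileOpener` opens each one in binary mode, yielding `(path, stream)` pairs. A standalone sketch of the same composition (the folder path is made up):

```python
from torch.utils.data.datapipes.iter import FileLister, FileOpener

# List every file under the folder recursively, then open each in binary mode;
# iterating the pipe yields (path, stream) tuples.
dp = FileOpener(FileLister("some/extracted/folder", recursive=True), mode="rb")
for path, stream in dp:
    print(path)
    stream.close()
```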
```diff
@@ -101,7 +94,7 @@ class OnlineResource(abc.ABC):
         path_candidates = {self.download(root, skip_integrity_check=skip_integrity_check)}
 
         # If the only thing we find is the raw file, we use it and optionally perform some preprocessing steps.
         if path_candidates == {path}:
-            if self._preprocess:
+            if self._preprocess is not None:
                 path = self._preprocess(path)
 
         # Otherwise we use the path with the fewest suffixes. This gives us the extracted > decompressed > raw priority
         # that we want.
```
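The "fewest suffixes" comment describes a selection rule like the one below: an extracted folder has no suffix, a decompressed file one, and the raw archive two, so taking `min` over `len(path.suffixes)` yields the extracted > decompressed > raw priority (a sketch with made-up paths):

```python
import pathlib

candidates = {
    pathlib.Path("root/images"),         # extracted folder: no suffix
    pathlib.Path("root/images.tar"),     # decompressed file: one suffix
    pathlib.Path("root/images.tar.gz"),  # raw download: two suffixes
}
best = min(candidates, key=lambda path: len(path.suffixes))
print(best)  # root/images
```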