Unverified Commit 92524362 authored by Philip Meier, committed by GitHub

add API for new style datasets (#4473)



* add API for new style datasets

* cleanup
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent 3f9b2d9c
# torchvision/prototype/datasets/__init__.py
from ._home import home
from . import decoder, utils
# Load these last, since they depend on the modules above being loaded first
from ._api import register, list, info, load
from ._folder import from_data_folder, from_image_folder
# torchvision/prototype/datasets/_api.py
import io
from typing import Any, Callable, Dict, List, Optional
import torch
from torch.utils.data import IterDataPipe
from torchvision.prototype.datasets import home
from torchvision.prototype.datasets.decoder import pil
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo
from torchvision.prototype.datasets.utils._internal import add_suggestion
DATASETS: Dict[str, Dataset] = {}
def register(dataset: Dataset) -> None:
DATASETS[dataset.name] = dataset
def list() -> List[str]:
return sorted(DATASETS.keys())
def find(name: str) -> Dataset:
name = name.lower()
try:
return DATASETS[name]
except KeyError as error:
raise ValueError(
add_suggestion(
f"Unknown dataset '{name}'.",
word=name,
possibilities=DATASETS.keys(),
alternative_hint=lambda _: (
"You can use torchvision.datasets.list() to get a list of all available datasets."
),
)
) from error
def info(name: str) -> DatasetInfo:
return find(name).info
def load(
name: str,
*,
shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]] = None,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil,
split: str = "train",
**options: Any,
) -> IterDataPipe[Dict[str, Any]]:
dataset = find(name)
config = dataset.info.make_config(split=split, **options)
root = home() / name
return dataset.to_datapipe(root, config=config, shuffler=shuffler, decoder=decoder)
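A usage sketch for the API above, assuming a hypothetical dataset named "mnist" has already been passed to register() and that its samples carry an "image" key (both assumptions, not part of this module):

from torchvision.prototype import datasets

print(datasets.list())         # sorted names of all registered datasets
print(datasets.info("mnist"))  # the DatasetInfo of the dataset

# load() resolves the name, builds a config from the keyword options,
# and returns an IterDataPipe that yields one Dict[str, Any] per sample.
datapipe = datasets.load("mnist", split="train")
for sample in datapipe:
    image = sample["image"]  # assumed sample key
    break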
# torchvision/prototype/datasets/_home.py
import os
import pathlib
from typing import Optional, Union
from torch.hub import _get_torch_home
HOME = pathlib.Path(_get_torch_home()) / "datasets" / "vision"
def home(home: Optional[Union[str, pathlib.Path]] = None) -> pathlib.Path:
global HOME
if home is not None:
HOME = pathlib.Path(home).expanduser().resolve()
return HOME
home = os.getenv("TORCHVISION_DATASETS_HOME")
if home is not None:
return pathlib.Path(home)
return HOME
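A sketch of the precedence implemented by home(): an explicit argument persistently overrides the default, but the TORCHVISION_DATASETS_HOME environment variable is checked on every argument-less call (paths below are placeholders):

import os
from torchvision.prototype.datasets import home

print(home())  # default: $TORCH_HOME/datasets/vision

os.environ["TORCHVISION_DATASETS_HOME"] = "/data/vision"
print(home())  # /data/vision -- the environment variable wins

home("~/my_datasets")  # overrides the module-level default ...
del os.environ["TORCHVISION_DATASETS_HOME"]
print(home())  # ... and is returned once the environment variable is gone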
# torchvision/prototype/datasets/decoder.py
import io
import numpy as np
import PIL.Image
import torch
__all__ = ["pil"]
def pil(file: io.IOBase, mode: str = "RGB") -> torch.Tensor:
image = PIL.Image.open(file).convert(mode.upper())
return torch.from_numpy(np.array(image, copy=True)).permute((2, 0, 1))
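A round-trip sketch for the decoder: pil() yields a channels-first uint8 tensor. The in-memory PNG below exists only to make the example self-contained:

import io
import PIL.Image
from torchvision.prototype.datasets.decoder import pil

buffer = io.BytesIO()
PIL.Image.new("RGB", (4, 3)).save(buffer, format="PNG")
buffer.seek(0)

tensor = pil(buffer)
print(tensor.shape, tensor.dtype)  # torch.Size([3, 3, 4]) torch.uint8 -- (C, H, W)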
# torchvision/prototype/datasets/utils/__init__.py
from . import _internal
from ._dataset import DatasetConfig, DatasetInfo, Dataset
from ._resource import LocalResource, OnlineResource, HttpResource, GDriveResource
# torchvision/prototype/datasets/utils/_dataset.py
import abc
import io
import os
import pathlib
import textwrap
from collections.abc import Mapping
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Union,
NoReturn,
Iterable,
Tuple,
)
import torch
from torch.utils.data import IterDataPipe
from torchvision.prototype.datasets.utils._internal import (
add_suggestion,
sequence_to_str,
)
from ._resource import OnlineResource
def make_repr(name: str, items: Iterable[Tuple[str, Any]]) -> str:
def to_str(sep: str) -> str:
return sep.join([f"{key}={value}" for key, value in items])
prefix = f"{name}("
postfix = ")"
body = to_str(", ")
line_length = int(os.environ.get("COLUMNS", 80))
body_too_long = (len(prefix) + len(body) + len(postfix)) > line_length
multiline_body = len(body.splitlines()) > 1
if not (body_too_long or multiline_body):
return prefix + body + postfix
body = textwrap.indent(to_str(",\n"), " " * 2)
return f"{prefix}\n{body}\n{postfix}"
class DatasetConfig(Mapping):
def __init__(self, *args, **kwargs):
data = dict(*args, **kwargs)
self.__dict__["__data__"] = data
self.__dict__["__final_hash__"] = hash(tuple(data.items()))
def __getitem__(self, name: str) -> Any:
return self.__dict__["__data__"][name]
def __iter__(self):
return iter(self.__dict__["__data__"].keys())
def __len__(self):
return len(self.__dict__["__data__"])
def __getattr__(self, name: str) -> Any:
try:
return self[name]
except KeyError as error:
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
) from error
def __setitem__(self, key: Any, value: Any) -> NoReturn:
raise RuntimeError(f"'{type(self).__name__}' object is immutable")
def __setattr__(self, key: Any, value: Any) -> NoReturn:
raise RuntimeError(f"'{type(self).__name__}' object is immutable")
def __delitem__(self, key: Any) -> NoReturn:
raise RuntimeError(f"'{type(self).__name__}' object is immutable")
def __delattr__(self, item: Any) -> NoReturn:
raise RuntimeError(f"'{type(self).__name__}' object is immutable")
def __hash__(self) -> int:
return self.__dict__["__final_hash__"]
def __eq__(self, other: Any) -> bool:
if not isinstance(other, DatasetConfig):
return NotImplemented
return hash(self) == hash(other)
def __repr__(self) -> str:
return make_repr(type(self).__name__, self.items())
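A behavior sketch for DatasetConfig with made-up keys: mapping and attribute access read the same data, every mutation path raises, and equality is defined through the precomputed hash:

config = DatasetConfig(split="train", fold=1)

print(config.split, config["fold"])  # train 1
print(len(config), list(config))     # 2 ['split', 'fold']

try:
    config.split = "test"
except RuntimeError as error:
    print(error)  # 'DatasetConfig' object is immutable

assert config == DatasetConfig(split="train", fold=1)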
class DatasetInfo:
def __init__(
self,
name: str,
*,
categories: Union[int, Sequence[str], str, pathlib.Path],
citation: Optional[str] = None,
homepage: Optional[str] = None,
license: Optional[str] = None,
valid_options: Optional[Dict[str, Sequence]] = None,
) -> None:
self.name = name.lower()
if isinstance(categories, int):
categories = [str(label) for label in range(categories)]
elif isinstance(categories, (str, pathlib.Path)):
with open(pathlib.Path(categories).expanduser().resolve(), "r") as fh:
    # strip the trailing newline that each raw line carries
    categories = [line.strip() for line in fh]
self.categories = categories
self.citation = citation
self.homepage = homepage
self.license = license
valid_split: Dict[str, Sequence] = dict(split=["train"])
if valid_options is None:
valid_options = valid_split
elif "split" not in valid_options:
valid_options.update(valid_split)
elif "train" not in valid_options["split"]:
raise ValueError(
f"'train' has to be a valid argument for option 'split', "
f"but found only {sequence_to_str(valid_options['split'], separate_last='and ')}."
)
self._valid_options: Dict[str, Sequence] = valid_options
@property
def default_config(self) -> DatasetConfig:
return DatasetConfig(
{name: valid_args[0] for name, valid_args in self._valid_options.items()}
)
def make_config(self, **options: Any) -> DatasetConfig:
for name, arg in options.items():
if name not in self._valid_options:
raise ValueError(
add_suggestion(
f"Unknown option '{name}' of dataset {self.name}.",
word=name,
possibilities=sorted(self._valid_options.keys()),
)
)
valid_args = self._valid_options[name]
if arg not in valid_args:
    raise ValueError(
        add_suggestion(
            f"Invalid argument '{arg}' for option '{name}' of dataset {self.name}.",
            # difflib works on strings, but valid arguments may be e.g. ints
            word=str(arg),
            possibilities=[str(valid_arg) for valid_arg in valid_args],
        )
    )
return DatasetConfig(self.default_config, **options)
def __repr__(self) -> str:
items = [("name", self.name)]
for key in ("citation", "homepage", "license"):
value = getattr(self, key)
if value is not None:
items.append((key, value))
items.extend(
sorted(
(key, sequence_to_str(value))
for key, value in self._valid_options.items()
)
)
return make_repr(type(self).__name__, items)
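A sketch of option validation with a made-up dataset: make_config() fills unspecified options from default_config and rejects unknown options and invalid arguments with the suggestion machinery from _internal:

info = DatasetInfo(
    "dummy",
    categories=10,  # expanded to the labels "0" through "9"
    valid_options=dict(split=("train", "test"), fold=(1, 2, 3)),
)

print(info.default_config)  # DatasetConfig(split=train, fold=1)

config = info.make_config(split="test")
print(config.fold)  # 1 -- taken from the default config

info.make_config(split="val")
# ValueError: Invalid argument 'val' for option 'split' of dataset dummy.
# Can be 'train', or 'test'.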
class Dataset(abc.ABC):
@property
@abc.abstractmethod
def info(self) -> DatasetInfo:
pass
@property
def name(self) -> str:
return self.info.name
@property
def default_config(self) -> DatasetConfig:
return self.info.default_config
@abc.abstractmethod
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
pass
@abc.abstractmethod
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]],
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
pass
def to_datapipe(
self,
root: Union[str, pathlib.Path],
*,
config: Optional[DatasetConfig] = None,
shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]] = None,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
) -> IterDataPipe[Dict[str, Any]]:
if not config:
config = self.info.default_config
resource_dps = [
resource.to_datapipe(root) for resource in self.resources(config)
]
return self._make_datapipe(resource_dps, config=config, shuffler=shuffler, decoder=decoder)
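A sketch of the subclass contract with a hypothetical dataset (name, URL, and checksum are placeholders): to_datapipe() drives resources() and _make_datapipe(), so a concrete dataset only fills in those two plus info:

from typing import List

from torchvision.prototype.datasets.utils import (
    Dataset,
    DatasetConfig,
    DatasetInfo,
    HttpResource,
    OnlineResource,
)

class Dummy(Dataset):
    @property
    def info(self) -> DatasetInfo:
        return DatasetInfo("dummy", categories=10)

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        return [HttpResource("https://example.com/dummy.tar", sha256="0" * 64)]

    def _make_datapipe(self, resource_dps, *, config, shuffler, decoder):
        # Would parse the archive datapipe into sample dicts here.
        dp = resource_dps[0]
        return dp if shuffler is None else shuffler(dp)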
# torchvision/prototype/datasets/utils/_internal.py
import collections.abc
import difflib
from typing import Collection, Sequence, Callable
__all__ = [
"sequence_to_str",
"add_suggestion",
]
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
    if len(seq) == 1:
        return f"'{seq[0]}'"
    # no trailing period here -- callers append their own sentence punctuation
    return (
        f"""'{"', '".join([str(item) for item in seq[:-1]])}', """
        f"""{separate_last}'{seq[-1]}'"""
    )
def add_suggestion(
msg: str,
*,
word: str,
possibilities: Collection[str],
close_match_hint: Callable[
[str], str
] = lambda close_match: f"Did you mean '{close_match}'?",
alternative_hint: Callable[
[Sequence[str]], str
] = lambda possibilities: f"Can be {sequence_to_str(possibilities, separate_last='or ')}.",
) -> str:
if not isinstance(possibilities, collections.abc.Sequence):
possibilities = sorted(possibilities)
suggestions = difflib.get_close_matches(word, possibilities, 1)
hint = (
close_match_hint(suggestions[0])
if suggestions
else alternative_hint(possibilities)
)
return f"{msg.strip()} {hint}"
# torchvision/prototype/datasets/utils/_resource.py
import os.path
import pathlib
from typing import Optional, Union
from urllib.parse import urlparse
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.iter import FileLoader, IterableWrapper
# FIXME: compute the real checksum of the file instead of this placeholder
def compute_sha256(_) -> str:
    return ""
class LocalResource:
def __init__(
self, path: Union[str, pathlib.Path], *, sha256: Optional[str] = None
) -> None:
self.path = pathlib.Path(path).expanduser().resolve()
self.file_name = self.path.name
self.sha256 = sha256 or compute_sha256(self.path)
def to_datapipe(self) -> IterDataPipe:
return FileLoader(IterableWrapper((str(self.path),)))
class OnlineResource:
def __init__(self, url: str, *, sha256: str, file_name: str) -> None:
self.url = url
self.sha256 = sha256
self.file_name = file_name
def to_datapipe(self, root: Union[str, pathlib.Path]) -> IterDataPipe:
path = (pathlib.Path(root) / self.file_name).expanduser().resolve()
        # FIXME: this should download self.url to `path` and verify the checksum first
return FileLoader(IterableWrapper((str(path),)))
# TODO: add support for mirrors
# TODO: add support for http -> https
class HttpResource(OnlineResource):
def __init__(
self, url: str, *, sha256: str, file_name: Optional[str] = None
) -> None:
if not file_name:
file_name = os.path.basename(urlparse(url).path)
super().__init__(url, sha256=sha256, file_name=file_name)
class GDriveResource(OnlineResource):
def __init__(self, id: str, *, sha256: str, file_name: str) -> None:
# TODO: can we maybe do a head request to extract the file name?
url = f"https://drive.google.com/file/d/{id}/view"
super().__init__(url, sha256=sha256, file_name=file_name)
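A sketch of the resource classes (URL, id, and checksums are placeholders): HttpResource infers its file name from the URL path, and GDriveResource builds the sharing URL from the id:

resource = HttpResource("https://example.com/files/archive.tar.gz", sha256="0" * 64)
print(resource.file_name)  # archive.tar.gz -- taken from the URL path

gdrive = GDriveResource("1a2b3c", sha256="0" * 64, file_name="images.zip")
print(gdrive.url)  # https://drive.google.com/file/d/1a2b3c/view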