Unverified Commit 1ac6e8b9 authored by Philip Meier, committed by GitHub

Refactor and simplify prototype datasets (#5778)



* refactor prototype datasets to inherit from IterDataPipe (#5448)

* refactor prototype datasets to inherit from IterDataPipe

* depend on new architecture

* fix missing file detection

* remove unrelated file

* reinstate decorator for mock registering

* options -> config

* remove passing of info to mock data functions

* refactor categories file generation

* fix imagenet

* fix prototype datasets data loading tests (#5711)

* reenable serialization test

* cleanup

* fix dill test

* trigger CI

* patch DILL_AVAILABLE for pickle serialization

* revert CI changes

* remove dill test and traversable test

* add data loader test

* parametrize over only_datapipe

* draw one sample rather than exhaust data loader

* cleanup

* trigger CI

* migrate VOC prototype dataset (#5743)

* migrate VOC prototype dataset

* cleanup

* revert unrelated mock data changes

* remove categories annotations

* move properties to constructor

* readd homepage

* migrate CIFAR prototype datasets (#5751)

* migrate country211 prototype dataset (#5753)

* migrate CLEVR prototype dataset (#5752)

* migrate coco prototype (#5473)

* migrate coco prototype

* revert unrelated change

* add kwargs to super constructor call

* remove unneeded changes

* fix docstring position

* make kwargs explicit

* add dependencies to docstring

* fix missing dependency message

* Migrate PCAM prototype dataset (#5745)

* Port PCAM

* skip_integrity_check

* Update torchvision/prototype/datasets/_builtin/pcam.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Address comments
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>

* Migrate DTD prototype dataset (#5757)

* Migrate DTD prototype dataset

* Docstring

* Apply suggestions from code review
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Migrate GTSRB prototype dataset (#5746)

* Migrate GTSRB prototype dataset

* ufmt

* Address comments

* Apparently mypy doesn't know that __len__ returns ints. How cute.

* why is the CI not triggered??

* Update torchvision/prototype/datasets/_builtin/gtsrb.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* migrate CelebA prototype dataset (#5750)

* migrate CelebA prototype dataset

* inline split_id

* Migrate Food101 prototype dataset (#5758)

* Migrate Food101 dataset

* Added length

* Update torchvision/prototype/datasets/_builtin/food101.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Migrate Fer2013 prototype dataset (#5759)

* Migrate Fer2013 prototype dataset

* Update torchvision/prototype/datasets/_builtin/fer2013.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Migrate EuroSAT prototype dataset (#5760)

* Migrate Semeion prototype dataset (#5761)

* migrate caltech prototype datasets (#5749)

* migrate caltech prototype datasets

* resolve third party dependencies

* Migrate Oxford Pets prototype dataset (#5764)

* Migrate Oxford Pets prototype dataset

* Update torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* migrate mnist prototype datasets (#5480)

* migrate MNIST prototype datasets

* Update torchvision/prototype/datasets/_builtin/mnist.py
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

* Migrate Stanford Cars prototype dataset (#5767)

* Migrate Stanford Cars prototype dataset

* Address comments

* fix category file generation (#5770)

* fix category file generation

* revert unrelated change

* revert unrelated change

* migrate cub200 prototype dataset (#5765)

* migrate cub200 prototype dataset

* address comments

* fix category-file-generation

* Migrate USPS prototype dataset (#5771)

* migrate SBD prototype dataset (#5772)

* migrate SBD prototype dataset

* reuse categories

* Migrate SVHN prototype dataset (#5769)

* add test to enforce __len__ is working on prototype datasets (#5742)

* reactivate special dataset tests

* add missing annotation

* Cleanup prototype dataset implementation (#5774)

* Remove Dataset2 class

* Move read_categories_file out of DatasetInfo

* Remove FrozenBunch and FrozenMapping

* Remove test_prototype_datasets_api.py and move missing dep test somewhere else

* ufmt

* Let read_categories_file accept names instead of paths

* Mypy

* flake8

* fix category file reading
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* update prototype dataset README (#5777)

* update prototype dataset README

* fix header level

* Apply suggestions from code review
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent 5f74f031
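
The refactor replaces the DatasetInfo/DatasetConfig machinery with datasets that are plain IterDataPipes, configured entirely through their constructors and registered via register_dataset/register_info. As a rough usage sketch based on the diffs below (the load entry point and the exact keyword names are prototype-only surface and may change):

# Rough usage sketch, assuming the prototype `load` entry point and the constructor
# keywords introduced by this PR; not a stable API.
from torch.utils.data import DataLoader
from torchvision.prototype import datasets

dataset = datasets.load("mnist", split="test")  # datasets are now IterDataPipes
print(len(dataset))  # __len__ is implemented per split, e.g. 10_000 for the MNIST test split

loader = DataLoader(dataset, batch_size=None)  # datapipes plug into a DataLoader directly
sample = next(iter(loader))
print(sample["image"].shape, sample["label"])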
torchvision/prototype/datasets/_builtin/imagenet.py

+import enum
 import functools
 import pathlib
 import re
-from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Match, cast
+from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Match, cast, Union
 from torchdata.datapipes.iter import (
     IterDataPipe,
@@ -14,23 +15,30 @@ from torchdata.datapipes.iter import (
     Enumerator,
 )
 from torchvision.prototype.datasets.utils import (
-    Dataset,
-    DatasetConfig,
-    DatasetInfo,
     OnlineResource,
     ManualDownloadResource,
+    Dataset,
 )
 from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
-    BUILTIN_DIR,
-    path_comparator,
     getitem,
     read_mat,
     hint_sharding,
     hint_shuffling,
+    read_categories_file,
+    path_accessor,
 )
 from torchvision.prototype.features import Label, EncodedImage
-from torchvision.prototype.utils._internal import FrozenMapping
+
+from .._api import register_dataset, register_info
+
+NAME = "imagenet"
+
+
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    categories, wnids = zip(*read_categories_file(NAME))
+    return dict(categories=categories, wnids=wnids)

 class ImageNetResource(ManualDownloadResource):
@@ -38,32 +46,33 @@ class ImageNetResource(ManualDownloadResource):
         super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs)

+class ImageNetDemux(enum.IntEnum):
+    META = 0
+    LABEL = 1
+
+
+@register_dataset(NAME)
 class ImageNet(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        name = "imagenet"
-        categories, wnids = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{name}.categories"))
-        return DatasetInfo(
-            name,
-            dependencies=("scipy",),
-            categories=categories,
-            homepage="https://www.image-net.org/",
-            valid_options=dict(split=("train", "val", "test")),
-            extra=dict(
-                wnid_to_category=FrozenMapping(zip(wnids, categories)),
-                category_to_wnid=FrozenMapping(zip(categories, wnids)),
-                sizes=FrozenMapping(
-                    [
-                        (DatasetConfig(split="train"), 1_281_167),
-                        (DatasetConfig(split="val"), 50_000),
-                        (DatasetConfig(split="test"), 100_000),
-                    ]
-                ),
-            ),
-        )
-
-    def supports_sharded(self) -> bool:
-        return True
+    """
+    - **homepage**: https://www.image-net.org/
+    """
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        skip_integrity_check: bool = False,
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", {"train", "val", "test"})
+
+        info = _info()
+        categories, wnids = info["categories"], info["wnids"]
+        self._categories = categories
+        self._wnids = wnids
+        self._wnid_to_category = dict(zip(wnids, categories))
+
+        super().__init__(root, skip_integrity_check=skip_integrity_check)

     _IMAGES_CHECKSUMS = {
         "train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
@@ -71,15 +80,15 @@ class ImageNet(Dataset):
         "test_v10102019": "9cf7f8249639510f17d3d8a0deb47cd22a435886ba8e29e2b3223e65a4079eb4",
     }

-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
-        name = "test_v10102019" if config.split == "test" else config.split
+    def _resources(self) -> List[OnlineResource]:
+        name = "test_v10102019" if self._split == "test" else self._split
         images = ImageNetResource(
             file_name=f"ILSVRC2012_img_{name}.tar",
             sha256=self._IMAGES_CHECKSUMS[name],
         )
         resources: List[OnlineResource] = [images]
-        if config.split == "val":
+        if self._split == "val":
             devkit = ImageNetResource(
                 file_name="ILSVRC2012_devkit_t12.tar.gz",
                 sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953",
@@ -88,19 +97,12 @@ class ImageNet(Dataset):
         return resources

-    def num_samples(self, config: DatasetConfig) -> int:
-        return {
-            "train": 1_281_167,
-            "val": 50_000,
-            "test": 100_000,
-        }[config.split]
-
     _TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")

     def _prepare_train_data(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
         path = pathlib.Path(data[0])
         wnid = cast(Match[str], self._TRAIN_IMAGE_NAME_PATTERN.match(path.name))["wnid"]
-        label = Label.from_category(self.info.extra.wnid_to_category[wnid], categories=self.categories)
+        label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories)
         return (label, wnid), data

     def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]:
@@ -108,10 +110,17 @@ class ImageNet(Dataset):
     def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
         return {
-            "meta.mat": 0,
-            "ILSVRC2012_validation_ground_truth.txt": 1,
+            "meta.mat": ImageNetDemux.META,
+            "ILSVRC2012_validation_ground_truth.txt": ImageNetDemux.LABEL,
         }.get(pathlib.Path(data[0]).name)

+    # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
+    # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
+    _WNID_MAP = {
+        "n03126707": "construction crane",
+        "n03710721": "tank suit",
+    }
+
     def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]:
         synsets = read_mat(data[1], squeeze_me=True)["synsets"]
         return [
@@ -121,21 +130,20 @@ class ImageNet(Dataset):
             if num_children == 0
         ]

-    def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: List[str]) -> str:
+    def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: Tuple[str, ...]) -> str:
         return wnids[int(imagenet_label) - 1]

     _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")

-    def _val_test_image_key(self, data: Tuple[str, Any]) -> int:
-        path = pathlib.Path(data[0])
-        return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name).group("id"))  # type: ignore[union-attr]
+    def _val_test_image_key(self, path: pathlib.Path) -> int:
+        return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name)["id"])  # type: ignore[index]

     def _prepare_val_data(
         self, data: Tuple[Tuple[int, str], Tuple[str, BinaryIO]]
     ) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
         label_data, image_data = data
         _, wnid = label_data
-        label = Label.from_category(self.info.extra.wnid_to_category[wnid], categories=self.categories)
+        label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories)
         return (label, wnid), image_data

     def _prepare_sample(
@@ -150,19 +158,17 @@ class ImageNet(Dataset):
             image=EncodedImage.from_file(buffer),
         )

-    def _make_datapipe(
-        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
-    ) -> IterDataPipe[Dict[str, Any]]:
-        if config.split in {"train", "test"}:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
+        if self._split in {"train", "test"}:
             dp = resource_dps[0]
             # the train archive is a tar of tars
-            if config.split == "train":
+            if self._split == "train":
                 dp = TarArchiveLoader(dp)
             dp = hint_shuffling(dp)
             dp = hint_sharding(dp)
-            dp = Mapper(dp, self._prepare_train_data if config.split == "train" else self._prepare_test_data)
+            dp = Mapper(dp, self._prepare_train_data if self._split == "train" else self._prepare_test_data)
         else:  # config.split == "val":
             images_dp, devkit_dp = resource_dps
@@ -174,6 +180,7 @@ class ImageNet(Dataset):
             _, wnids = zip(*next(iter(meta_dp)))
             label_dp = LineReader(label_dp, decode=True, return_path=False)
+            # We cannot use self._wnids here, since we use a different order than the dataset
             label_dp = Mapper(label_dp, functools.partial(self._imagenet_label_to_wnid, wnids=wnids))
             label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1)
             label_dp = hint_shuffling(label_dp)
@@ -183,26 +190,29 @@ class ImageNet(Dataset):
                 label_dp,
                 images_dp,
                 key_fn=getitem(0),
-                ref_key_fn=self._val_test_image_key,
+                ref_key_fn=path_accessor(self._val_test_image_key),
                 buffer_size=INFINITE_BUFFER_SIZE,
             )
             dp = Mapper(dp, self._prepare_val_data)
         return Mapper(dp, self._prepare_sample)

-    # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
-    # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
-    _WNID_MAP = {
-        "n03126707": "construction crane",
-        "n03710721": "tank suit",
-    }
+    def __len__(self) -> int:
+        return {
+            "train": 1_281_167,
+            "val": 50_000,
+            "test": 100_000,
+        }[self._split]
+
+    def _filter_meta(self, data: Tuple[str, Any]) -> bool:
+        return self._classifiy_devkit(data) == ImageNetDemux.META

-    def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, ...]]:
-        config = self.info.make_config(split="val")
-        resources = self.resources(config)
-        devkit_dp = resources[1].load(root)
-        meta_dp = Filter(devkit_dp, path_comparator("name", "meta.mat"))
+    def _generate_categories(self) -> List[Tuple[str, ...]]:
+        self._split = "val"
+        resources = self._resources()
+        devkit_dp = resources[1].load(self._root)
+        meta_dp = Filter(devkit_dp, self._filter_meta)
         meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids)
         categories_and_wnids = cast(List[Tuple[str, ...]], next(iter(meta_dp)))
...
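For illustration, constructing the refactored ImageNet class from the diff above could look like the sketch below; the root path is a placeholder, and since ImageNet uses ManualDownloadResource the ILSVRC2012 archives are assumed to already sit in that directory.

# Hypothetical usage of the refactored ImageNet dataset; the path is a placeholder and
# the archives are assumed to have been downloaded manually (ManualDownloadResource).
from torchvision.prototype.datasets._builtin.imagenet import ImageNet

dataset = ImageNet("/data/imagenet", split="val")
assert len(dataset) == 50_000  # per the __len__ mapping added in this PR

sample = next(iter(dataset))
print(sample["label"], sample["image"])  # Label with categories plus an EncodedImage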
torchvision/prototype/datasets/_builtin/mnist.py

@@ -7,12 +7,13 @@ from typing import Any, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO, U
 import torch
 from torchdata.datapipes.iter import IterDataPipe, Demultiplexer, Mapper, Zipper, Decompressor
-from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource
+from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding, hint_shuffling
 from torchvision.prototype.features import Image, Label
 from torchvision.prototype.utils._internal import fromfile

-__all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"]
+from .._api import register_dataset, register_info

 prod = functools.partial(functools.reduce, operator.mul)
@@ -61,14 +62,14 @@ class _MNISTBase(Dataset):
     _URL_BASE: Union[str, Sequence[str]]

     @abc.abstractmethod
-    def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]:
+    def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]:
         pass

-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+    def _resources(self) -> List[OnlineResource]:
         (images_file, images_sha256), (
             labels_file,
             labels_sha256,
-        ) = self._files_and_checksums(config)
+        ) = self._files_and_checksums()
         url_bases = self._URL_BASE
         if isinstance(url_bases, str):
@@ -82,21 +83,21 @@ class _MNISTBase(Dataset):
         return [images, labels]

-    def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional[int]]:
+    def start_and_stop(self) -> Tuple[Optional[int], Optional[int]]:
         return None, None

-    def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]:
+    _categories: List[str]
+
+    def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]:
         image, label = data
         return dict(
             image=Image(image),
-            label=Label(label, dtype=torch.int64, categories=self.categories),
+            label=Label(label, dtype=torch.int64, categories=self._categories),
         )

-    def _make_datapipe(
-        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         images_dp, labels_dp = resource_dps
-        start, stop = self.start_and_stop(config)
+        start, stop = self.start_and_stop()
         images_dp = Decompressor(images_dp)
         images_dp = MNISTFileReader(images_dp, start=start, stop=stop)
@@ -107,19 +108,31 @@ class _MNISTBase(Dataset):
         dp = Zipper(images_dp, labels_dp)
         dp = hint_shuffling(dp)
         dp = hint_sharding(dp)
-        return Mapper(dp, functools.partial(self._prepare_sample, config=config))
+        return Mapper(dp, self._prepare_sample)
+
+
+@register_info("mnist")
+def _mnist_info() -> Dict[str, Any]:
+    return dict(
+        categories=[str(label) for label in range(10)],
+    )

+
+@register_dataset("mnist")
 class MNIST(_MNISTBase):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "mnist",
-            categories=10,
-            homepage="http://yann.lecun.com/exdb/mnist",
-            valid_options=dict(
-                split=("train", "test"),
-            ),
-        )
+    """
+    - **homepage**: http://yann.lecun.com/exdb/mnist
+    """
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        skip_integrity_check: bool = False,
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", ("train", "test"))
+        super().__init__(root, skip_integrity_check=skip_integrity_check)

     _URL_BASE: Union[str, Sequence[str]] = (
         "http://yann.lecun.com/exdb/mnist",
@@ -132,8 +145,8 @@ class MNIST(_MNISTBase):
         "t10k-labels-idx1-ubyte.gz": "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6",
     }

-    def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]:
-        prefix = "train" if config.split == "train" else "t10k"
+    def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]:
+        prefix = "train" if self._split == "train" else "t10k"
         images_file = f"{prefix}-images-idx3-ubyte.gz"
         labels_file = f"{prefix}-labels-idx1-ubyte.gz"
         return (images_file, self._CHECKSUMS[images_file]), (
@@ -141,28 +154,35 @@ class MNIST(_MNISTBase):
             self._CHECKSUMS[labels_file],
         )

+    _categories = _mnist_info()["categories"]
+
+    def __len__(self) -> int:
+        return 60_000 if self._split == "train" else 10_000
+
+
+@register_info("fashionmnist")
+def _fashionmnist_info() -> Dict[str, Any]:
+    return dict(
+        categories=[
+            "T-shirt/top",
+            "Trouser",
+            "Pullover",
+            "Dress",
+            "Coat",
+            "Sandal",
+            "Shirt",
+            "Sneaker",
+            "Bag",
+            "Ankle boot",
+        ],
+    )
+
+
+@register_dataset("fashionmnist")
 class FashionMNIST(MNIST):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "fashionmnist",
-            categories=(
-                "T-shirt/top",
-                "Trouser",
-                "Pullover",
-                "Dress",
-                "Coat",
-                "Sandal",
-                "Shirt",
-                "Sneaker",
-                "Bag",
-                "Ankle boot",
-            ),
-            homepage="https://github.com/zalandoresearch/fashion-mnist",
-            valid_options=dict(
-                split=("train", "test"),
-            ),
-        )
+    """
+    - **homepage**: https://github.com/zalandoresearch/fashion-mnist
+    """

     _URL_BASE = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com"
     _CHECKSUMS = {
@@ -172,17 +192,21 @@ class FashionMNIST(MNIST):
         "t10k-labels-idx1-ubyte.gz": "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5",
     }

+    _categories = _fashionmnist_info()["categories"]
+
+
+@register_info("kmnist")
+def _kmnist_info() -> Dict[str, Any]:
+    return dict(
+        categories=["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"],
+    )
+
+
+@register_dataset("kmnist")
 class KMNIST(MNIST):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "kmnist",
-            categories=["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"],
-            homepage="http://codh.rois.ac.jp/kmnist/index.html.en",
-            valid_options=dict(
-                split=("train", "test"),
-            ),
-        )
+    """
+    - **homepage**: http://codh.rois.ac.jp/kmnist/index.html.en
+    """

     _URL_BASE = "http://codh.rois.ac.jp/kmnist/dataset/kmnist"
     _CHECKSUMS = {
@@ -192,36 +216,46 @@ class KMNIST(MNIST):
         "t10k-labels-idx1-ubyte.gz": "20bb9a0ef54c7db3efc55a92eef5582c109615df22683c380526788f98e42a1c",
     }

+    _categories = _kmnist_info()["categories"]
+
+
+@register_info("emnist")
+def _emnist_info() -> Dict[str, Any]:
+    return dict(
+        categories=list(string.digits + string.ascii_uppercase + string.ascii_lowercase),
+    )
+
+
+@register_dataset("emnist")
 class EMNIST(_MNISTBase):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "emnist",
-            categories=list(string.digits + string.ascii_uppercase + string.ascii_lowercase),
-            homepage="https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist",
-            valid_options=dict(
-                split=("train", "test"),
-                image_set=(
-                    "Balanced",
-                    "By_Merge",
-                    "By_Class",
-                    "Letters",
-                    "Digits",
-                    "MNIST",
-                ),
-            ),
-        )
+    """
+    - **homepage**: https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist
+    """
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        image_set: str = "Balanced",
+        skip_integrity_check: bool = False,
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", ("train", "test"))
+        self._image_set = self._verify_str_arg(
+            image_set, "image_set", ("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST")
+        )
+        super().__init__(root, skip_integrity_check=skip_integrity_check)

     _URL_BASE = "https://rds.westernsydney.edu.au/Institutes/MARCS/BENS/EMNIST"

-    def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]:
-        prefix = f"emnist-{config.image_set.replace('_', '').lower()}-{config.split}"
+    def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]:
+        prefix = f"emnist-{self._image_set.replace('_', '').lower()}-{self._split}"
         images_file = f"{prefix}-images-idx3-ubyte.gz"
         labels_file = f"{prefix}-labels-idx1-ubyte.gz"
-        # Since EMNIST provides the data files inside an archive, we don't need provide checksums for them
+        # Since EMNIST provides the data files inside an archive, we don't need to provide checksums for them
         return (images_file, ""), (labels_file, "")

-    def resources(self, config: Optional[DatasetConfig] = None) -> List[OnlineResource]:
+    def _resources(self) -> List[OnlineResource]:
         return [
             HttpResource(
                 f"{self._URL_BASE}/emnist-gzip.zip",
@@ -229,9 +263,9 @@ class EMNIST(_MNISTBase):
             )
         ]

-    def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> Optional[int]:
+    def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
         path = pathlib.Path(data[0])
-        (images_file, _), (labels_file, _) = self._files_and_checksums(config)
+        (images_file, _), (labels_file, _) = self._files_and_checksums()
         if path.name == images_file:
             return 0
         elif path.name == labels_file:
@@ -239,6 +273,8 @@ class EMNIST(_MNISTBase):
         else:
             return None

+    _categories = _emnist_info()["categories"]
+
     _LABEL_OFFSETS = {
         38: 1,
         39: 1,
@@ -251,45 +287,71 @@ class EMNIST(_MNISTBase):
         46: 9,
     }

-    def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]:
         # In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper).
         # That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense,
-        # i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create this gaps. For example,
-        # since there is no 'c', 'd' corresponds to
+        # i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create these gaps. For
+        # example, since there is no 'c', 'd' corresponds to
         # label 38 (10 digits + 26 uppercase letters + 3rd unmerged lower case letter - 1 for zero indexing),
         # and at the same time corresponds to
         # index 39 (10 digits + 26 uppercase letters + 4th lower case letter - 1 for zero indexing)
-        # in self.categories. Thus, we need to add 1 to the label to correct this.
-        if config.image_set in ("Balanced", "By_Merge"):
+        # in self._categories. Thus, we need to add 1 to the label to correct this.
+        if self._image_set in ("Balanced", "By_Merge"):
             image, label = data
             label += self._LABEL_OFFSETS.get(int(label), 0)
             data = (image, label)
-        return super()._prepare_sample(data, config=config)
+        return super()._prepare_sample(data)

-    def _make_datapipe(
-        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         archive_dp = resource_dps[0]
         images_dp, labels_dp = Demultiplexer(
             archive_dp,
             2,
-            functools.partial(self._classify_archive, config=config),
+            self._classify_archive,
             drop_none=True,
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        return super()._make_datapipe([images_dp, labels_dp], config=config)
+        return super()._datapipe([images_dp, labels_dp])
+
+    def __len__(self) -> int:
+        return {
+            ("train", "Balanced"): 112_800,
+            ("train", "By_Merge"): 697_932,
+            ("train", "By_Class"): 697_932,
+            ("train", "Letters"): 124_800,
+            ("train", "Digits"): 240_000,
+            ("train", "MNIST"): 60_000,
+            ("test", "Balanced"): 18_800,
+            ("test", "By_Merge"): 116_323,
+            ("test", "By_Class"): 116_323,
+            ("test", "Letters"): 20_800,
+            ("test", "Digits"): 40_000,
+            ("test", "MNIST"): 10_000,
+        }[(self._split, self._image_set)]
+
+
+@register_info("qmnist")
+def _qmnist_info() -> Dict[str, Any]:
+    return dict(
+        categories=[str(label) for label in range(10)],
+    )
+
+
+@register_dataset("qmnist")
 class QMNIST(_MNISTBase):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "qmnist",
-            categories=10,
-            homepage="https://github.com/facebookresearch/qmnist",
-            valid_options=dict(
-                split=("train", "test", "test10k", "test50k", "nist"),
-            ),
-        )
+    """
+    - **homepage**: https://github.com/facebookresearch/qmnist
+    """
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        skip_integrity_check: bool = False,
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", ("train", "test", "test10k", "test50k", "nist"))
+        super().__init__(root, skip_integrity_check=skip_integrity_check)

     _URL_BASE = "https://raw.githubusercontent.com/facebookresearch/qmnist/master"
     _CHECKSUMS = {
@@ -301,9 +363,9 @@ class QMNIST(_MNISTBase):
         "xnist-labels-idx2-int.xz": "db042968723ec2b7aed5f1beac25d2b6e983b9286d4f4bf725f1086e5ae55c4f",
     }

-    def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]:
-        prefix = "xnist" if config.split == "nist" else f"qmnist-{'train' if config.split== 'train' else 'test'}"
-        suffix = "xz" if config.split == "nist" else "gz"
+    def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]:
+        prefix = "xnist" if self._split == "nist" else f"qmnist-{'train' if self._split == 'train' else 'test'}"
+        suffix = "xz" if self._split == "nist" else "gz"
         images_file = f"{prefix}-images-idx3-ubyte.{suffix}"
         labels_file = f"{prefix}-labels-idx2-int.{suffix}"
         return (images_file, self._CHECKSUMS[images_file]), (
@@ -311,13 +373,13 @@ class QMNIST(_MNISTBase):
             self._CHECKSUMS[labels_file],
         )

-    def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional[int]]:
+    def start_and_stop(self) -> Tuple[Optional[int], Optional[int]]:
         start: Optional[int]
         stop: Optional[int]
-        if config.split == "test10k":
+        if self._split == "test10k":
             start = 0
             stop = 10000
-        elif config.split == "test50k":
+        elif self._split == "test50k":
             start = 10000
             stop = None
         else:
@@ -325,10 +387,12 @@ class QMNIST(_MNISTBase):
         return start, stop

-    def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]:
+    _categories = _emnist_info()["categories"]
+
+    def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]:
         image, ann = data
         label, *extra_anns = ann
-        sample = super()._prepare_sample((image, label), config=config)
+        sample = super()._prepare_sample((image, label))
         sample.update(
             dict(
@@ -340,3 +404,12 @@ class QMNIST(_MNISTBase):
         )
         sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in extra_anns[-2:]])))
         return sample
+
+    def __len__(self) -> int:
+        return {
+            "train": 60_000,
+            "test": 60_000,
+            "test10k": 10_000,
+            "test50k": 50_000,
+            "nist": 402_953,
+        }[self._split]
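The EMNIST label-offset handling above is easiest to see with the concrete value from the code comment; the snippet below replays that 'd' example as an illustration (only the offset entry visible in the diff is reproduced, the full table lives in _LABEL_OFFSETS).

import string

# Categories as registered for EMNIST: digits, then uppercase, then lowercase letters.
categories = list(string.digits + string.ascii_uppercase + string.ascii_lowercase)
label_offsets = {38: 1}  # only the entry visible in the diff above

raw_label = 38  # dense label stored for 'd' in the Balanced/By_Merge splits (there is no 'c')
index = raw_label + label_offsets.get(raw_label, 0)
print(categories[index])  # -> 'd', i.e. index 39 in the digits+uppercase+lowercase list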
torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py

 import enum
 import pathlib
-from typing import Any, Dict, List, Optional, Tuple, BinaryIO
+from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union
 from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, CSVDictParser
 from torchvision.prototype.datasets.utils import (
     Dataset,
-    DatasetConfig,
-    DatasetInfo,
     HttpResource,
     OnlineResource,
 )
@@ -16,27 +14,41 @@ from torchvision.prototype.datasets.utils._internal import (
     hint_shuffling,
     getitem,
     path_accessor,
+    read_categories_file,
     path_comparator,
 )
 from torchvision.prototype.features import Label, EncodedImage

+from .._api import register_dataset, register_info
+
-class OxfordIITPetDemux(enum.IntEnum):
+NAME = "oxford-iiit-pet"
+
+
+class OxfordIIITPetDemux(enum.IntEnum):
     SPLIT_AND_CLASSIFICATION = 0
     SEGMENTATIONS = 1

-class OxfordIITPet(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "oxford-iiit-pet",
-            homepage="https://www.robots.ox.ac.uk/~vgg/data/pets/",
-            valid_options=dict(
-                split=("trainval", "test"),
-            ),
-        )
-
-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    return dict(categories=read_categories_file(NAME))
+
+
+@register_dataset(NAME)
+class OxfordIIITPet(Dataset):
+    """Oxford IIIT Pet Dataset
+    homepage="https://www.robots.ox.ac.uk/~vgg/data/pets/",
+    """
+
+    def __init__(
+        self, root: Union[str, pathlib.Path], *, split: str = "trainval", skip_integrity_check: bool = False
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", {"trainval", "test"})
+        self._categories = _info()["categories"]
+        super().__init__(root, skip_integrity_check=skip_integrity_check)
+
+    def _resources(self) -> List[OnlineResource]:
         images = HttpResource(
             "https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz",
             sha256="67195c5e1c01f1ab5f9b6a5d22b8c27a580d896ece458917e61d459337fa318d",
@@ -51,8 +63,8 @@ class OxfordIITPet(Dataset):
     def _classify_anns(self, data: Tuple[str, Any]) -> Optional[int]:
         return {
-            "annotations": OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION,
-            "trimaps": OxfordIITPetDemux.SEGMENTATIONS,
+            "annotations": OxfordIIITPetDemux.SPLIT_AND_CLASSIFICATION,
+            "trimaps": OxfordIIITPetDemux.SEGMENTATIONS,
         }.get(pathlib.Path(data[0]).parent.name)

     def _filter_images(self, data: Tuple[str, Any]) -> bool:
@@ -70,7 +82,7 @@ class OxfordIITPet(Dataset):
         image_path, image_buffer = image_data
         return dict(
-            label=Label(int(classification_data["label"]) - 1, categories=self.categories),
+            label=Label(int(classification_data["label"]) - 1, categories=self._categories),
             species="cat" if classification_data["species"] == "1" else "dog",
             segmentation_path=segmentation_path,
             segmentation=EncodedImage.from_file(segmentation_buffer),
@@ -78,9 +90,7 @@ class OxfordIITPet(Dataset):
             image=EncodedImage.from_file(image_buffer),
         )

-    def _make_datapipe(
-        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         images_dp, anns_dp = resource_dps
         images_dp = Filter(images_dp, self._filter_images)
@@ -93,9 +103,7 @@ class OxfordIITPet(Dataset):
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        split_and_classification_dp = Filter(
-            split_and_classification_dp, path_comparator("name", f"{config.split}.txt")
-        )
+        split_and_classification_dp = Filter(split_and_classification_dp, path_comparator("name", f"{self._split}.txt"))
         split_and_classification_dp = CSVDictParser(
             split_and_classification_dp, fieldnames=("image_id", "label", "species"), delimiter=" "
         )
@@ -122,15 +130,14 @@ class OxfordIITPet(Dataset):
         return Mapper(dp, self._prepare_sample)

     def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool:
-        return self._classify_anns(data) == OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION
+        return self._classify_anns(data) == OxfordIIITPetDemux.SPLIT_AND_CLASSIFICATION

-    def _generate_categories(self, root: pathlib.Path) -> List[str]:
-        config = self.default_config
-        resources = self.resources(config)
-        dp = resources[1].load(root)
+    def _generate_categories(self) -> List[str]:
+        resources = self._resources()
+        dp = resources[1].load(self._root)
         dp = Filter(dp, self._filter_split_and_classification_anns)
-        dp = Filter(dp, path_comparator("name", f"{config.split}.txt"))
+        dp = Filter(dp, path_comparator("name", "trainval.txt"))
         dp = CSVDictParser(dp, fieldnames=("image_id", "label"), delimiter=" ")
         raw_categories_and_labels = {(data["image_id"].rsplit("_", 1)[0], data["label"]) for data in dp}
@@ -138,3 +145,6 @@ class OxfordIITPet(Dataset):
             *sorted(raw_categories_and_labels, key=lambda raw_category_and_label: int(raw_category_and_label[1]))
         )
         return [" ".join(part.title() for part in raw_category.split("_")) for raw_category in raw_categories]
+
+    def __len__(self) -> int:
+        return 3_680 if self._split == "trainval" else 3_669
torchvision/prototype/datasets/_builtin/pcam.py

 import io
+import pathlib
 from collections import namedtuple
-from typing import Any, Dict, List, Optional, Tuple, Iterator
+from typing import Any, Dict, List, Optional, Tuple, Iterator, Union
 from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper
 from torchvision.prototype import features
 from torchvision.prototype.datasets.utils import (
     Dataset,
-    DatasetConfig,
-    DatasetInfo,
     OnlineResource,
     GDriveResource,
 )
@@ -17,6 +16,11 @@ from torchvision.prototype.datasets.utils._internal import (
 )
 from torchvision.prototype.features import Label

+from .._api import register_dataset, register_info
+
+NAME = "pcam"
+
 class PCAMH5Reader(IterDataPipe[Tuple[str, io.IOBase]]):
     def __init__(
@@ -40,15 +44,25 @@ class PCAMH5Reader(IterDataPipe[Tuple[str, io.IOBase]]):
 _Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256"))

+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    return dict(categories=["0", "1"])
+
+
+@register_dataset(NAME)
 class PCAM(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "pcam",
-            homepage="https://github.com/basveeling/pcam",
-            categories=2,
-            valid_options=dict(split=("train", "test", "val")),
-            dependencies=["h5py"],
-        )
+    # TODO write proper docstring
+    """PCAM Dataset
+
+    homepage="https://github.com/basveeling/pcam"
+    """
+
+    def __init__(
+        self, root: Union[str, pathlib.Path], split: str = "train", *, skip_integrity_check: bool = False
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", {"train", "val", "test"})
+        self._categories = _info()["categories"]
+        super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("h5py",))

     _RESOURCES = {
         "train": (
@@ -89,10 +103,10 @@ class PCAM(Dataset):
         ),
     }

-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+    def _resources(self) -> List[OnlineResource]:
         return [  # = [images resource, targets resource]
             GDriveResource(file_name=file_name, id=gdrive_id, sha256=sha256, preprocess="decompress")
-            for file_name, gdrive_id, sha256 in self._RESOURCES[config.split]
+            for file_name, gdrive_id, sha256 in self._RESOURCES[self._split]
         ]

     def _prepare_sample(self, data: Tuple[Any, Any]) -> Dict[str, Any]:
@@ -100,12 +114,10 @@ class PCAM(Dataset):
         return {
             "image": features.Image(image.transpose(2, 0, 1)),
-            "label": Label(target.item()),
+            "label": Label(target.item(), categories=self._categories),
         }

-    def _make_datapipe(
-        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         images_dp, targets_dp = resource_dps
@@ -116,3 +128,6 @@ class PCAM(Dataset):
         dp = hint_shuffling(dp)
         dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)
+
+    def __len__(self) -> int:
+        return 262_144 if self._split == "train" else 32_768
torchvision/prototype/datasets/_builtin/sbd.py

 import pathlib
 import re
-from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO
+from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO, Union
 import numpy as np
 from torchdata.datapipes.iter import (
@@ -11,13 +11,7 @@ from torchdata.datapipes.iter import (
     IterKeyZipper,
     LineReader,
 )
-from torchvision.prototype.datasets.utils import (
-    Dataset,
-    DatasetConfig,
-    DatasetInfo,
-    HttpResource,
-    OnlineResource,
-)
+from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
     read_mat,
@@ -26,22 +20,42 @@ from torchvision.prototype.datasets.utils._internal import (
     path_comparator,
     hint_sharding,
     hint_shuffling,
+    read_categories_file,
 )
 from torchvision.prototype.features import _Feature, EncodedImage

+from .._api import register_dataset, register_info
+
+NAME = "sbd"
+
+
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    return dict(categories=read_categories_file(NAME))
+
+
+@register_dataset(NAME)
 class SBD(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "sbd",
-            dependencies=("scipy",),
-            homepage="http://home.bharathh.info/pubs/codes/SBD/download.html",
-            valid_options=dict(
-                split=("train", "val", "train_noval"),
-            ),
-        )
-
-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+    """
+    - **homepage**: http://home.bharathh.info/pubs/codes/SBD/download.html
+    - **dependencies**:
+        - <scipy `https://scipy.org`>_
+    """
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        skip_integrity_check: bool = False,
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", ("train", "val", "train_noval"))
+        self._categories = _info()["categories"]
+        super().__init__(root, dependencies=("scipy",), skip_integrity_check=skip_integrity_check)
+
+    def _resources(self) -> List[OnlineResource]:
         archive = HttpResource(
             "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz",
             sha256="6a5a2918d5c73ce032fdeba876574d150d9d04113ab87540a1304cbcc715be53",
@@ -85,12 +99,7 @@ class SBD(Dataset):
             segmentation=_Feature(anns["Segmentation"].item()),
         )

-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         archive_dp, extra_split_dp = resource_dps
         archive_dp = resource_dps[0]
@@ -101,10 +110,10 @@ class SBD(Dataset):
             buffer_size=INFINITE_BUFFER_SIZE,
             drop_none=True,
         )
-        if config.split == "train_noval":
+        if self._split == "train_noval":
             split_dp = extra_split_dp
-        split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
+        split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
         split_dp = LineReader(split_dp, decode=True)
         split_dp = hint_shuffling(split_dp)
         split_dp = hint_sharding(split_dp)
@@ -120,10 +129,17 @@ class SBD(Dataset):
         )
         return Mapper(dp, self._prepare_sample)

-    def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
-        resources = self.resources(self.default_config)
+    def __len__(self) -> int:
+        return {
+            "train": 8_498,
+            "val": 2_857,
+            "train_noval": 5_623,
+        }[self._split]
+
+    def _generate_categories(self) -> Tuple[str, ...]:
+        resources = self._resources()
-        dp = resources[0].load(root)
+        dp = resources[0].load(self._root)
         dp = Filter(dp, path_comparator("name", "category_names.m"))
         dp = LineReader(dp)
         dp = Mapper(dp, bytes.decode, input_col=1)
...
torchvision/prototype/datasets/_builtin/semeion.py

-from typing import Any, Dict, List, Tuple
+import pathlib
+from typing import Any, Dict, List, Tuple, Union
 import torch
 from torchdata.datapipes.iter import (
@@ -8,24 +9,34 @@ from torchdata.datapipes.iter import (
 )
 from torchvision.prototype.datasets.utils import (
     Dataset,
-    DatasetConfig,
-    DatasetInfo,
     HttpResource,
     OnlineResource,
 )
 from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
 from torchvision.prototype.features import Image, OneHotLabel

+from .._api import register_dataset, register_info
+
+NAME = "semeion"
+
+
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    return dict(categories=[str(i) for i in range(10)])
+
+
+@register_dataset(NAME)
 class SEMEION(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "semeion",
-            categories=10,
-            homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit",
-        )
-
-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+    """Semeion dataset
+    homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit",
+    """
+
+    def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None:
+        self._categories = _info()["categories"]
+        super().__init__(root, skip_integrity_check=skip_integrity_check)
+
+    def _resources(self) -> List[OnlineResource]:
         data = HttpResource(
             "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data",
             sha256="f43228ae3da5ea6a3c95069d53450b86166770e3b719dcc333182128fe08d4b1",
@@ -36,18 +47,16 @@ class SEMEION(Dataset):
         image_data, label_data = data[:256], data[256:-1]
         return dict(
-            image=Image(torch.tensor([float(pixel) for pixel in image_data], dtype=torch.uint8).reshape(16, 16)),
-            label=OneHotLabel([int(label) for label in label_data], categories=self.categories),
+            image=Image(torch.tensor([float(pixel) for pixel in image_data], dtype=torch.float).reshape(16, 16)),
+            label=OneHotLabel([int(label) for label in label_data], categories=self._categories),
         )

-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = CSVParser(dp, delimiter=" ")
         dp = hint_shuffling(dp)
         dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)
+
+    def __len__(self) -> int:
+        return 1_593
import pathlib
from typing import Any, Dict, List, Tuple, Iterator, BinaryIO, Union

from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    hint_sharding,
    hint_shuffling,
    path_comparator,
    read_mat,
    read_categories_file,
)
from torchvision.prototype.features import BoundingBox, EncodedImage, Label

from .._api import register_dataset, register_info


class StanfordCarsLabelReader(IterDataPipe[Tuple[int, int, int, int, int, str]]):
    def __init__(self, datapipe: IterDataPipe[Dict[str, Any]]) -> None:
@@ -18,16 +26,31 @@ class StanfordCarsLabelReader(IterDataPipe[Tuple[int, int, int, int, int, str]])
            yield tuple(ann)  # type: ignore[misc]


NAME = "stanford-cars"


@register_info(NAME)
def _info() -> Dict[str, Any]:
    return dict(categories=read_categories_file(NAME))


@register_dataset(NAME)
class StanfordCars(Dataset):
    """Stanford Cars dataset.
    homepage="https://ai.stanford.edu/~jkrause/cars/car_dataset.html",
    dependencies=scipy
    """

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        *,
        split: str = "train",
        skip_integrity_check: bool = False,
    ) -> None:
        self._split = self._verify_str_arg(split, "split", {"train", "test"})
        self._categories = _info()["categories"]
        super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("scipy",))

    _URL_ROOT = "https://ai.stanford.edu/~jkrause/"
    _URLS = {
@@ -44,9 +67,9 @@ class StanfordCars(Dataset):
        "car_devkit": "512b227b30e2f0a8aab9e09485786ab4479582073a144998da74d64b801fd288",
    }

    def _resources(self) -> List[OnlineResource]:
        resources: List[OnlineResource] = [HttpResource(self._URLS[self._split], sha256=self._CHECKSUM[self._split])]
        if self._split == "train":
            resources.append(HttpResource(url=self._URLS["car_devkit"], sha256=self._CHECKSUM["car_devkit"]))
        else:
@@ -65,19 +88,14 @@ class StanfordCars(Dataset):
        return dict(
            path=path,
            image=image,
            label=Label(target[4] - 1, categories=self._categories),
            bounding_box=BoundingBox(target[:4], format="xyxy", image_size=image.image_size),
        )

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        images_dp, targets_dp = resource_dps
        if self._split == "train":
            targets_dp = Filter(targets_dp, path_comparator("name", "cars_train_annos.mat"))
        targets_dp = StanfordCarsLabelReader(targets_dp)
        dp = Zipper(images_dp, targets_dp)
@@ -85,12 +103,14 @@ class StanfordCars(Dataset):
        dp = hint_sharding(dp)
        return Mapper(dp, self._prepare_sample)

    def _generate_categories(self) -> List[str]:
        resources = self._resources()

        devkit_dp = resources[1].load(self._root)
        meta_dp = Filter(devkit_dp, path_comparator("name", "cars_meta.mat"))
        _, meta_file = next(iter(meta_dp))

        return list(read_mat(meta_file, squeeze_me=True)["class_names"])

    def __len__(self) -> int:
        return 8_144 if self._split == "train" else 8_041
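A sketch of constructing the migrated StanfordCars dataset directly with the new constructor shown above. The root path is a placeholder, the module path is an assumption based on the builtin layout, and scipy must be installed because of dependencies=("scipy",).

# Sketch only, not part of the torchvision sources.
from torchvision.prototype.datasets._builtin.stanford_cars import StanfordCars  # assumed module path

dataset = StanfordCars("~/datasets/stanford-cars", split="test", skip_integrity_check=True)
print(len(dataset))  # 8_041 for the test split, from __len__ above
for sample in dataset:
    # each sample is a dict with "path", "image", "label" and "bounding_box" entries
    print(sample["label"], sample["bounding_box"])
    break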
import pathlib
from typing import Any, Dict, List, Tuple, BinaryIO, Union

import numpy as np
from torchdata.datapipes.iter import (
@@ -8,8 +9,6 @@ from torchdata.datapipes.iter import (
)
from torchvision.prototype.datasets.utils import (
    Dataset,
    HttpResource,
    OnlineResource,
)
@@ -20,16 +19,33 @@ from torchvision.prototype.datasets.utils._internal import (
)
from torchvision.prototype.features import Label, Image

from .._api import register_dataset, register_info

NAME = "svhn"


@register_info(NAME)
def _info() -> Dict[str, Any]:
    return dict(categories=[str(c) for c in range(10)])


@register_dataset(NAME)
class SVHN(Dataset):
    """SVHN Dataset.
    homepage="http://ufldl.stanford.edu/housenumbers/",
    dependencies = scipy
    """

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        *,
        split: str = "train",
        skip_integrity_check: bool = False,
    ) -> None:
        self._split = self._verify_str_arg(split, "split", {"train", "test", "extra"})
        self._categories = _info()["categories"]
        super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("scipy",))

    _CHECKSUMS = {
        "train": "435e94d69a87fde4fd4d7f3dd208dfc32cb6ae8af2240d066de1df7508d083b8",
@@ -37,10 +53,10 @@ class SVHN(Dataset):
        "extra": "a133a4beb38a00fcdda90c9489e0c04f900b660ce8a316a5e854838379a71eb3",
    }

    def _resources(self) -> List[OnlineResource]:
        data = HttpResource(
            f"http://ufldl.stanford.edu/housenumbers/{self._split}_32x32.mat",
            sha256=self._CHECKSUMS[self._split],
        )
        return [data]
@@ -60,18 +76,20 @@ class SVHN(Dataset):
        return dict(
            image=Image(image_array.transpose((2, 0, 1))),
            label=Label(int(label_array) % 10, categories=self._categories),
        )

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        dp = resource_dps[0]
        dp = Mapper(dp, self._read_images_and_labels)
        dp = UnBatcher(dp)
        dp = hint_shuffling(dp)
        dp = hint_sharding(dp)
        return Mapper(dp, self._prepare_sample)

    def __len__(self) -> int:
        return {
            "train": 73_257,
            "test": 26_032,
            "extra": 531_131,
        }[self._split]
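Because the datasets now inherit from IterDataPipe, they can be passed to a DataLoader without an adapter. A sketch with a placeholder setup; it assumes `datasets.load` accepts dataset-specific keyword arguments such as `split`, and the identity collate_fn is purely illustrative.

# Sketch only, not part of the torchvision sources.
from torch.utils.data import DataLoader
from torchvision.prototype import datasets

dataset = datasets.load("svhn", split="train")
loader = DataLoader(dataset, batch_size=32, collate_fn=lambda batch: batch)
batch = next(iter(loader))  # a list of 32 sample dicts with "image" and "label" entries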
import pathlib
from typing import Any, Dict, List, Union

import torch
from torchdata.datapipes.iter import IterDataPipe, LineReader, Mapper, Decompressor
from torchvision.prototype.datasets.utils import Dataset, OnlineResource, HttpResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.features import Image, Label

from .._api import register_dataset, register_info

NAME = "usps"


@register_info(NAME)
def _info() -> Dict[str, Any]:
    return dict(categories=[str(c) for c in range(10)])


@register_dataset(NAME)
class USPS(Dataset):
    """USPS Dataset
    homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps",
    """

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        *,
        split: str = "train",
        skip_integrity_check: bool = False,
    ) -> None:
        self._split = self._verify_str_arg(split, "split", {"train", "test"})
        self._categories = _info()["categories"]
        super().__init__(root, skip_integrity_check=skip_integrity_check)

    _URL = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass"
@@ -29,8 +46,8 @@ class USPS(Dataset):
        ),
    }

    def _resources(self) -> List[OnlineResource]:
        return [USPS._RESOURCES[self._split]]

    def _prepare_sample(self, line: str) -> Dict[str, Any]:
        label, *values = line.strip().split(" ")
@@ -38,17 +55,15 @@ class USPS(Dataset):
        pixels = torch.tensor(values).add_(1).div_(2)

        return dict(
            image=Image(pixels.reshape(16, 16)),
            label=Label(int(label) - 1, categories=self._categories),
        )

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        dp = Decompressor(resource_dps[0])
        dp = LineReader(dp, decode=True, return_path=False)
        dp = hint_shuffling(dp)
        dp = hint_sharding(dp)
        return Mapper(dp, self._prepare_sample)

    def __len__(self) -> int:
        return 7_291 if self._split == "train" else 2_007
import enum
import functools
import pathlib
from typing import Any, Dict, List, Optional, Tuple, BinaryIO, cast, Union
from xml.etree import ElementTree

from torchdata.datapipes.iter import (
@@ -12,13 +13,7 @@ from torchdata.datapipes.iter import (
    LineReader,
)
from torchvision.datasets import VOCDetection
from torchvision.prototype.datasets.utils import OnlineResource, HttpResource, Dataset
from torchvision.prototype.datasets.utils._internal import (
    path_accessor,
    getitem,
@@ -26,34 +21,48 @@ from torchvision.prototype.datasets.utils._internal import (
    path_comparator,
    hint_sharding,
    hint_shuffling,
    read_categories_file,
)
from torchvision.prototype.features import BoundingBox, Label, EncodedImage

from .._api import register_dataset, register_info

NAME = "voc"


@register_info(NAME)
def _info() -> Dict[str, Any]:
    return dict(categories=read_categories_file(NAME))


@register_dataset(NAME)
class VOC(Dataset):
    """
    - **homepage**: http://host.robots.ox.ac.uk/pascal/VOC/
    """

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        *,
        split: str = "train",
        year: str = "2012",
        task: str = "detection",
        skip_integrity_check: bool = False,
    ) -> None:
        self._year = self._verify_str_arg(year, "year", ("2007", "2008", "2009", "2010", "2011", "2012"))
        if split == "test" and year != "2007":
            raise ValueError("`split='test'` is only available for `year='2007'`")
        else:
            self._split = self._verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
        self._task = self._verify_str_arg(task, "task", ("detection", "segmentation"))

        self._anns_folder = "Annotations" if task == "detection" else "SegmentationClass"
        self._split_folder = "Main" if task == "detection" else "Segmentation"

        self._categories = _info()["categories"]

        super().__init__(root, skip_integrity_check=skip_integrity_check)

    _TRAIN_VAL_ARCHIVES = {
        "2007": ("VOCtrainval_06-Nov-2007.tar", "7d8cd951101b0957ddfd7a530bdc8a94f06121cfc1e511bb5937e973020c7508"),
@@ -67,31 +76,27 @@ class VOC(Dataset):
        "2007": ("VOCtest_06-Nov-2007.tar", "6836888e2e01dca84577a849d339fa4f73e1e4f135d312430c4856b5609b4892")
    }

    def _resources(self) -> List[OnlineResource]:
        file_name, sha256 = (self._TEST_ARCHIVES if self._split == "test" else self._TRAIN_VAL_ARCHIVES)[self._year]
        archive = HttpResource(f"http://host.robots.ox.ac.uk/pascal/VOC/voc{self._year}/{file_name}", sha256=sha256)
        return [archive]
    def _is_in_folder(self, data: Tuple[str, Any], *, name: str, depth: int = 1) -> bool:
        path = pathlib.Path(data[0])
        return name in path.parent.parts[-depth:]

    class _Demux(enum.IntEnum):
        SPLIT = 0
        IMAGES = 1
        ANNS = 2

    def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
        if self._is_in_folder(data, name="ImageSets", depth=2):
            return self._Demux.SPLIT
        elif self._is_in_folder(data, name="JPEGImages"):
            return self._Demux.IMAGES
        elif self._is_in_folder(data, name=self._anns_folder):
            return self._Demux.ANNS
        else:
            return None
@@ -111,7 +116,7 @@ class VOC(Dataset):
                image_size=cast(Tuple[int, int], tuple(int(anns["size"][dim]) for dim in ("height", "width"))),
            ),
            labels=Label(
                [self._categories.index(instance["name"]) for instance in instances], categories=self._categories
            ),
        )
@@ -121,8 +126,6 @@ class VOC(Dataset):
    def _prepare_sample(
        self,
        data: Tuple[Tuple[Tuple[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]],
    ) -> Dict[str, Any]:
        split_and_image_data, ann_data = data
        _, image_data = split_and_image_data
@@ -130,29 +133,24 @@ class VOC(Dataset):
        ann_path, ann_buffer = ann_data

        return dict(
            (self._prepare_detection_ann if self._task == "detection" else self._prepare_segmentation_ann)(ann_buffer),
            image_path=image_path,
            image=EncodedImage.from_file(image_buffer),
            ann_path=ann_path,
        )

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        archive_dp = resource_dps[0]
        split_dp, images_dp, anns_dp = Demultiplexer(
            archive_dp,
            3,
            self._classify_archive,
            drop_none=True,
            buffer_size=INFINITE_BUFFER_SIZE,
        )

        split_dp = Filter(split_dp, functools.partial(self._is_in_folder, name=self._split_folder))
        split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
        split_dp = LineReader(split_dp, decode=True)
        split_dp = hint_shuffling(split_dp)
        split_dp = hint_sharding(split_dp)
@@ -166,25 +164,59 @@ class VOC(Dataset):
            ref_key_fn=path_accessor("stem"),
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        return Mapper(dp, self._prepare_sample)

    def __len__(self) -> int:
        return {
            ("train", "2007", "detection"): 2_501,
            ("train", "2007", "segmentation"): 209,
            ("train", "2008", "detection"): 2_111,
            ("train", "2008", "segmentation"): 511,
            ("train", "2009", "detection"): 3_473,
            ("train", "2009", "segmentation"): 749,
            ("train", "2010", "detection"): 4_998,
            ("train", "2010", "segmentation"): 964,
            ("train", "2011", "detection"): 5_717,
            ("train", "2011", "segmentation"): 1_112,
            ("train", "2012", "detection"): 5_717,
            ("train", "2012", "segmentation"): 1_464,
            ("val", "2007", "detection"): 2_510,
            ("val", "2007", "segmentation"): 213,
            ("val", "2008", "detection"): 2_221,
            ("val", "2008", "segmentation"): 512,
            ("val", "2009", "detection"): 3_581,
            ("val", "2009", "segmentation"): 750,
            ("val", "2010", "detection"): 5_105,
            ("val", "2010", "segmentation"): 964,
            ("val", "2011", "detection"): 5_823,
            ("val", "2011", "segmentation"): 1_111,
            ("val", "2012", "detection"): 5_823,
            ("val", "2012", "segmentation"): 1_449,
            ("trainval", "2007", "detection"): 5_011,
            ("trainval", "2007", "segmentation"): 422,
            ("trainval", "2008", "detection"): 4_332,
            ("trainval", "2008", "segmentation"): 1_023,
            ("trainval", "2009", "detection"): 7_054,
            ("trainval", "2009", "segmentation"): 1_499,
            ("trainval", "2010", "detection"): 10_103,
            ("trainval", "2010", "segmentation"): 1_928,
            ("trainval", "2011", "detection"): 11_540,
            ("trainval", "2011", "segmentation"): 2_223,
            ("trainval", "2012", "detection"): 11_540,
            ("trainval", "2012", "segmentation"): 2_913,
            ("test", "2007", "detection"): 4_952,
            ("test", "2007", "segmentation"): 210,
        }[(self._split, self._year, self._task)]

    def _filter_anns(self, data: Tuple[str, Any]) -> bool:
        return self._classify_archive(data) == self._Demux.ANNS

    def _generate_categories(self) -> List[str]:
        self._task = "detection"
        resources = self._resources()

        archive_dp = resources[0].load(self._root)
        dp = Filter(archive_dp, self._filter_anns)
        dp = Mapper(dp, self._parse_detection_ann, input_col=1)

        return sorted({instance["name"] for _, anns in dp for instance in anns["object"]})
@@ -2,25 +2,21 @@
import argparse
import csv
import sys

from torchvision.prototype import datasets
from torchvision.prototype.datasets.utils._internal import BUILTIN_DIR


def main(*names, force=False):
    for name in names:
        path = BUILTIN_DIR / f"{name}.categories"
        if path.exists() and not force:
            continue

        dataset = datasets.load(name)
        try:
            categories = dataset._generate_categories()
        except NotImplementedError:
            continue

...
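The per-dataset work of the updated script, spelled out: the dataset is obtained from the registry instead of the removed `find` helper, and `_generate_categories` no longer takes a root because the dataset resolves it in its constructor. The CSV writing that follows is elided above, so the snippet below mirrors only the visible part; "voc" is just an example name.

# Sketch only, not part of the torchvision sources.
from torchvision.prototype import datasets

dataset = datasets.load("voc")
try:
    categories = dataset._generate_categories()
except NotImplementedError:
    categories = None
print(categories[:5] if categories else "no categories to generate")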
from . import _internal  # usort: skip
from ._dataset import Dataset
from ._query import SampleQuery
from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource
import abc
import importlib
import pathlib
from typing import Any, Dict, List, Optional, Sequence, Union, Collection, Iterator

from torch.utils.data import IterDataPipe
from torchvision.datasets.utils import verify_str_arg

from ._resource import OnlineResource


class Dataset(IterDataPipe[Dict[str, Any]], abc.ABC):
    @staticmethod
    def _verify_str_arg(
        value: str,
        arg: Optional[str] = None,
        valid_values: Optional[Collection[str]] = None,
        *,
        custom_msg: Optional[str] = None,
    ) -> str:
        return verify_str_arg(value, arg, valid_values, custom_msg=custom_msg)
    def __init__(
        self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False, dependencies: Collection[str] = ()
    ) -> None:
        for dependency in dependencies:
            try:
                importlib.import_module(dependency)
            except ModuleNotFoundError:
                raise ModuleNotFoundError(
                    f"{type(self).__name__}() depends on the third-party package '{dependency}'. "
                    f"Please install it, for example with `pip install {dependency}`."
                ) from None

        self._root = pathlib.Path(root).expanduser().resolve()
        resources = [
            resource.load(self._root, skip_integrity_check=skip_integrity_check) for resource in self._resources()
        ]
        self._dp = self._datapipe(resources)

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        yield from self._dp

    @abc.abstractmethod
    def _resources(self) -> List[OnlineResource]:
        pass

    @abc.abstractmethod
    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        pass

    @abc.abstractmethod
    def __len__(self) -> int:
        pass

    def _generate_categories(self) -> Sequence[Union[str, Sequence[str]]]:
        raise NotImplementedError
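A minimal sketch of what a builtin dataset now has to implement against this ABC: _resources, _datapipe and __len__ are abstract, while _generate_categories stays optional. The class name, URL and checksum below are made up purely for illustration.

# Sketch only, not part of the torchvision sources.
from typing import Any, Dict, List, Tuple

from torchdata.datapipes.iter import IterDataPipe, Mapper
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling


class MyDataset(Dataset):
    def _resources(self) -> List[OnlineResource]:
        # placeholder URL and checksum
        return [HttpResource("https://example.com/my-dataset.tar", sha256="0" * 64)]

    def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
        path, buffer = data
        return dict(path=path)

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        dp = resource_dps[0]
        dp = hint_shuffling(dp)
        dp = hint_sharding(dp)
        return Mapper(dp, self._prepare_sample)

    def __len__(self) -> int:
        return 1_000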
import csv
import functools
import pathlib
import pickle
@@ -9,6 +10,7 @@ from typing import (
    Any,
    Tuple,
    TypeVar,
    List,
    Iterator,
    Dict,
    IO,
@@ -198,3 +200,11 @@ def hint_sharding(datapipe: IterDataPipe) -> ShardingFilter:

def hint_shuffling(datapipe: IterDataPipe[D]) -> Shuffler[D]:
    return Shuffler(datapipe, buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False)


def read_categories_file(name: str) -> List[Union[str, Sequence[str]]]:
    path = BUILTIN_DIR / f"{name}.categories"
    with open(path, newline="") as file:
        rows = list(csv.reader(file))
        rows = [row[0] if len(row) == 1 else row for row in rows]
        return rows
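The new helper reads the <name>.categories files written by the generation script: a one-column row collapses to a plain string, while multi-column rows stay lists. A small sketch of that behavior with made-up rows; csv.reader accepts any iterable of strings.

# Sketch only, not part of the torchvision sources.
import csv

rows = list(csv.reader(["apple_pie", "baby_back_ribs", "n01440764,tench"]))  # made-up rows
rows = [row[0] if len(row) == 1 else row for row in rows]
print(rows)  # ['apple_pie', 'baby_back_ribs', ['n01440764', 'tench']]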
@@ -2,20 +2,13 @@ import collections.abc
import difflib
import io
import mmap
import os
import os.path
import platform
import textwrap
from typing import (
    Any,
    BinaryIO,
    Callable,
    cast,
    Collection,
    Iterable,
    Iterator,
    Mapping,
    NoReturn,
    Sequence,
    Tuple,
    TypeVar,
@@ -30,9 +23,6 @@ from torchvision._utils import sequence_to_str
__all__ = [
    "add_suggestion",
    "FrozenMapping",
    "make_repr",
    "FrozenBunch",
    "fromfile",
    "ReadOnlyTensorBuffer",
    "apply_recursively",
@@ -60,82 +50,9 @@ def add_suggestion(
    return f"{msg.strip()} {hint}"


K = TypeVar("K")
D = TypeVar("D")


class FrozenMapping(Mapping[K, D]):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        data = dict(*args, **kwargs)
        self.__dict__["__data__"] = data
        self.__dict__["__final_hash__"] = hash(tuple(data.items()))

    def __getitem__(self, item: K) -> D:
        return cast(Mapping[K, D], self.__dict__["__data__"])[item]

    def __iter__(self) -> Iterator[K]:
        return iter(self.__dict__["__data__"].keys())

    def __len__(self) -> int:
        return len(self.__dict__["__data__"])

    def __immutable__(self) -> NoReturn:
        raise RuntimeError(f"'{type(self).__name__}' object is immutable")

    def __setitem__(self, key: K, value: Any) -> NoReturn:
        self.__immutable__()

    def __delitem__(self, key: K) -> NoReturn:
        self.__immutable__()

    def __hash__(self) -> int:
        return cast(int, self.__dict__["__final_hash__"])

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, FrozenMapping):
            return NotImplemented
        return hash(self) == hash(other)

    def __repr__(self) -> str:
        return repr(self.__dict__["__data__"])


def make_repr(name: str, items: Iterable[Tuple[str, Any]]) -> str:
    def to_str(sep: str) -> str:
        return sep.join([f"{key}={value}" for key, value in items])

    prefix = f"{name}("
    postfix = ")"
    body = to_str(", ")

    line_length = int(os.environ.get("COLUMNS", 80))
    body_too_long = (len(prefix) + len(body) + len(postfix)) > line_length
    multiline_body = len(str(body).splitlines()) > 1

    if not (body_too_long or multiline_body):
        return prefix + body + postfix

    body = textwrap.indent(to_str(",\n"), " " * 2)
    return f"{prefix}\n{body}\n{postfix}"


class FrozenBunch(FrozenMapping):
    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError as error:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") from error

    def __setattr__(self, key: Any, value: Any) -> NoReturn:
        self.__immutable__()

    def __delattr__(self, item: Any) -> NoReturn:
        self.__immutable__()

    def __repr__(self) -> str:
        return make_repr(type(self).__name__, self.items())


def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray:
    # A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable
    return bytearray(file.read(-1 if count == -1 else count * item_size))

...