"vscode:/vscode.git/clone" did not exist on "22b19d578e57f9b152eef4444738da68bbb33ce7"
Unverified commit 1ac6e8b9 authored by Philip Meier, committed by GitHub

Refactor and simplify prototype datasets (#5778)



* refactor prototype datasets to inherit from IterDataPipe (#5448)

* refactor prototype datasets to inherit from IterDataPipe

* depend on new architecture

* fix missing file detection

* remove unrelated file

* reinstate decorator for mock registering

* options -> config

* remove passing of info to mock data functions

* refactor categories file generation

* fix imagenet

* fix prototype datasets data loading tests (#5711)

* reenable serialization test

* cleanup

* fix dill test

* trigger CI

* patch DILL_AVAILABLE for pickle serialization

* revert CI changes

* remove dill test and traversable test

* add data loader test

* parametrize over only_datapipe

* draw one sample rather than exhaust data loader

* cleanup

* trigger CI

* migrate VOC prototype dataset (#5743)

* migrate VOC prototype dataset

* cleanup

* revert unrelated mock data changes

* remove categories annotations

* move properties to constructor

* readd homepage

* migrate CIFAR prototype datasets (#5751)

* migrate country211 prototype dataset (#5753)

* migrate CLEVR prototype dataset (#5752)

* migrate coco prototype (#5473)

* migrate coco prototype

* revert unrelated change

* add kwargs to super constructor call

* remove unneeded changes

* fix docstring position

* make kwargs explicit

* add dependencies to docstring

* fix missing dependency message

* Migrate PCAM prototype dataset (#5745)

* Port PCAM

* skip_integrity_check

* Update torchvision/prototype/datasets/_builtin/pcam.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Address comments
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Migrate DTD prototype dataset (#5757)

* Migrate DTD prototype dataset

* Docstring

* Apply suggestions from code review
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Migrate GTSRB prototype dataset (#5746)

* Migrate GTSRB prototype dataset

* ufmt

* Address comments

* Apparently mypy doesn't know that __len__ returns ints. How cute.

* why is the CI not triggered??

* Update torchvision/prototype/datasets/_builtin/gtsrb.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* migrate CelebA prototype dataset (#5750)

* migrate CelebA prototype dataset

* inline split_id

* Migrate Food101 prototype dataset (#5758)

* Migrate Food101 dataset

* Added length

* Update torchvision/prototype/datasets/_builtin/food101.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Migrate Fer2013 prototype dataset (#5759)

* Migrate Fer2013 prototype dataset

* Update torchvision/prototype/datasets/_builtin/fer2013.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Migrate EuroSAT prototype dataset (#5760)

* Migrate Semeion prototype dataset (#5761)

* migrate caltech prototype datasets (#5749)

* migrate caltech prototype datasets

* resolve third party dependencies

* Migrate Oxford Pets prototype dataset (#5764)

* Migrate Oxford Pets prototype dataset

* Update torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* migrate mnist prototype datasets (#5480)

* migrate MNIST prototype datasets

* Update torchvision/prototype/datasets/_builtin/mnist.py
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

* Migrate Stanford Cars prototype dataset (#5767)

* Migrate Stanford Cars prototype dataset

* Address comments

* fix category file generation (#5770)

* fix category file generation

* revert unrelated change

* revert unrelated change

* migrate cub200 prototype dataset (#5765)

* migrate cub200 prototype dataset

* address comments

* fix category-file-generation

* Migrate USPS prototype dataset (#5771)

* migrate SBD prototype dataset (#5772)

* migrate SBD prototype dataset

* reuse categories

* Migrate SVHN prototype dataset (#5769)

* add test to enforce __len__ is working on prototype datasets (#5742)

* reactivate special dataset tests

* add missing annotation

* Cleanup prototype dataset implementation (#5774)

* Remove Dataset2 class

* Move read_categories_file out of DatasetInfo

* Remove FrozenBunch and FrozenMapping

* Remove test_prototype_datasets_api.py and move missing dep test somewhere else

* ufmt

* Let read_categories_file accept names instead of paths

* Mypy

* flake8

* fix category file reading
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* update prototype dataset README (#5777)

* update prototype dataset README

* fix header level

* Apply suggestions from code review
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
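Taken together, the per-file diffs below all converge on one skeleton for builtin prototype datasets. A minimal sketch of that shape, meant to live as a module under `torchvision/prototype/datasets/_builtin/` (the dataset name `my-dataset`, the class `MyDataset`, the URL, the checksum, and the sample fields are placeholders, not part of this PR):

```python
import pathlib
from typing import Any, BinaryIO, Dict, List, Tuple, Union

from torchdata.datapipes.iter import IterDataPipe, Mapper
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.features import EncodedImage, Label

from .._api import register_dataset, register_info

NAME = "my-dataset"  # placeholder, not a dataset touched by this PR


@register_info(NAME)
def _info() -> Dict[str, Any]:
    # static metadata is now a plain dict instead of a DatasetInfo instance
    return dict(categories=["cat", "dog"])


@register_dataset(NAME)
class MyDataset(Dataset):
    """
    - **homepage**: https://example.com (placeholder)
    """

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        *,
        split: str = "train",
        skip_integrity_check: bool = False,
    ) -> None:
        # what used to be a DatasetConfig option becomes an explicit constructor argument
        self._split = self._verify_str_arg(split, "split", {"train", "test"})
        self._categories = _info()["categories"]
        super().__init__(root, skip_integrity_check=skip_integrity_check)

    def _resources(self) -> List[OnlineResource]:
        # resources no longer receive a config; options are read from `self`
        return [HttpResource("https://example.com/data.tar.gz", sha256="...")]

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        dp = resource_dps[0]
        dp = hint_shuffling(dp)
        dp = hint_sharding(dp)
        return Mapper(dp, self._prepare_sample)

    def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
        path, buffer = data
        # categories are attached to the Label directly instead of being looked up on self.info
        return dict(
            path=path,
            image=EncodedImage.from_file(buffer),
            label=Label(0, categories=self._categories),  # placeholder label
        )

    def __len__(self) -> int:
        # every dataset now reports a hard-coded number of samples per split
        return 1_000 if self._split == "train" else 100
```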
import enum
import functools
import pathlib
import re
from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Match, cast
from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Match, cast, Union
from torchdata.datapipes.iter import (
IterDataPipe,
......@@ -14,23 +15,30 @@ from torchdata.datapipes.iter import (
Enumerator,
)
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
OnlineResource,
ManualDownloadResource,
Dataset,
)
from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
BUILTIN_DIR,
path_comparator,
getitem,
read_mat,
hint_sharding,
hint_shuffling,
read_categories_file,
path_accessor,
)
from torchvision.prototype.features import Label, EncodedImage
from torchvision.prototype.utils._internal import FrozenMapping
from .._api import register_dataset, register_info
NAME = "imagenet"
@register_info(NAME)
def _info() -> Dict[str, Any]:
categories, wnids = zip(*read_categories_file(NAME))
return dict(categories=categories, wnids=wnids)
class ImageNetResource(ManualDownloadResource):
......@@ -38,32 +46,33 @@ class ImageNetResource(ManualDownloadResource):
super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs)
class ImageNetDemux(enum.IntEnum):
META = 0
LABEL = 1
@register_dataset(NAME)
class ImageNet(Dataset):
def _make_info(self) -> DatasetInfo:
name = "imagenet"
categories, wnids = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{name}.categories"))
return DatasetInfo(
name,
dependencies=("scipy",),
categories=categories,
homepage="https://www.image-net.org/",
valid_options=dict(split=("train", "val", "test")),
extra=dict(
wnid_to_category=FrozenMapping(zip(wnids, categories)),
category_to_wnid=FrozenMapping(zip(categories, wnids)),
sizes=FrozenMapping(
[
(DatasetConfig(split="train"), 1_281_167),
(DatasetConfig(split="val"), 50_000),
(DatasetConfig(split="test"), 100_000),
]
),
),
)
"""
- **homepage**: https://www.image-net.org/
"""
def supports_sharded(self) -> bool:
return True
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "val", "test"})
info = _info()
categories, wnids = info["categories"], info["wnids"]
self._categories = categories
self._wnids = wnids
self._wnid_to_category = dict(zip(wnids, categories))
super().__init__(root, skip_integrity_check=skip_integrity_check)
_IMAGES_CHECKSUMS = {
"train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
......@@ -71,15 +80,15 @@ class ImageNet(Dataset):
"test_v10102019": "9cf7f8249639510f17d3d8a0deb47cd22a435886ba8e29e2b3223e65a4079eb4",
}
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
name = "test_v10102019" if config.split == "test" else config.split
def _resources(self) -> List[OnlineResource]:
name = "test_v10102019" if self._split == "test" else self._split
images = ImageNetResource(
file_name=f"ILSVRC2012_img_{name}.tar",
sha256=self._IMAGES_CHECKSUMS[name],
)
resources: List[OnlineResource] = [images]
if config.split == "val":
if self._split == "val":
devkit = ImageNetResource(
file_name="ILSVRC2012_devkit_t12.tar.gz",
sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953",
......@@ -88,19 +97,12 @@ class ImageNet(Dataset):
return resources
def num_samples(self, config: DatasetConfig) -> int:
return {
"train": 1_281_167,
"val": 50_000,
"test": 100_000,
}[config.split]
_TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")
def _prepare_train_data(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
path = pathlib.Path(data[0])
wnid = cast(Match[str], self._TRAIN_IMAGE_NAME_PATTERN.match(path.name))["wnid"]
label = Label.from_category(self.info.extra.wnid_to_category[wnid], categories=self.categories)
label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories)
return (label, wnid), data
def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]:
......@@ -108,10 +110,17 @@ class ImageNet(Dataset):
def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
return {
"meta.mat": 0,
"ILSVRC2012_validation_ground_truth.txt": 1,
"meta.mat": ImageNetDemux.META,
"ILSVRC2012_validation_ground_truth.txt": ImageNetDemux.LABEL,
}.get(pathlib.Path(data[0]).name)
# Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
# and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
_WNID_MAP = {
"n03126707": "construction crane",
"n03710721": "tank suit",
}
def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]:
synsets = read_mat(data[1], squeeze_me=True)["synsets"]
return [
......@@ -121,21 +130,20 @@ class ImageNet(Dataset):
if num_children == 0
]
def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: List[str]) -> str:
def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: Tuple[str, ...]) -> str:
return wnids[int(imagenet_label) - 1]
_VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
def _val_test_image_key(self, data: Tuple[str, Any]) -> int:
path = pathlib.Path(data[0])
return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name).group("id")) # type: ignore[union-attr]
def _val_test_image_key(self, path: pathlib.Path) -> int:
return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name)["id"]) # type: ignore[index]
def _prepare_val_data(
self, data: Tuple[Tuple[int, str], Tuple[str, BinaryIO]]
) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
label_data, image_data = data
_, wnid = label_data
label = Label.from_category(self.info.extra.wnid_to_category[wnid], categories=self.categories)
label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories)
return (label, wnid), image_data
def _prepare_sample(
......@@ -150,19 +158,17 @@ class ImageNet(Dataset):
image=EncodedImage.from_file(buffer),
)
def _make_datapipe(
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
) -> IterDataPipe[Dict[str, Any]]:
if config.split in {"train", "test"}:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
if self._split in {"train", "test"}:
dp = resource_dps[0]
# the train archive is a tar of tars
if config.split == "train":
if self._split == "train":
dp = TarArchiveLoader(dp)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
dp = Mapper(dp, self._prepare_train_data if config.split == "train" else self._prepare_test_data)
dp = Mapper(dp, self._prepare_train_data if self._split == "train" else self._prepare_test_data)
else: # config.split == "val":
images_dp, devkit_dp = resource_dps
......@@ -174,6 +180,7 @@ class ImageNet(Dataset):
_, wnids = zip(*next(iter(meta_dp)))
label_dp = LineReader(label_dp, decode=True, return_path=False)
# We cannot use self._wnids here, since we use a different order than the dataset
label_dp = Mapper(label_dp, functools.partial(self._imagenet_label_to_wnid, wnids=wnids))
label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1)
label_dp = hint_shuffling(label_dp)
......@@ -183,26 +190,29 @@ class ImageNet(Dataset):
label_dp,
images_dp,
key_fn=getitem(0),
ref_key_fn=self._val_test_image_key,
ref_key_fn=path_accessor(self._val_test_image_key),
buffer_size=INFINITE_BUFFER_SIZE,
)
dp = Mapper(dp, self._prepare_val_data)
return Mapper(dp, self._prepare_sample)
# Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
# and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
_WNID_MAP = {
"n03126707": "construction crane",
"n03710721": "tank suit",
}
def __len__(self) -> int:
return {
"train": 1_281_167,
"val": 50_000,
"test": 100_000,
}[self._split]
def _filter_meta(self, data: Tuple[str, Any]) -> bool:
return self._classifiy_devkit(data) == ImageNetDemux.META
def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, ...]]:
config = self.info.make_config(split="val")
resources = self.resources(config)
def _generate_categories(self) -> List[Tuple[str, ...]]:
self._split = "val"
resources = self._resources()
devkit_dp = resources[1].load(root)
meta_dp = Filter(devkit_dp, path_comparator("name", "meta.mat"))
devkit_dp = resources[1].load(self._root)
meta_dp = Filter(devkit_dp, self._filter_meta)
meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids)
categories_and_wnids = cast(List[Tuple[str, ...]], next(iter(meta_dp)))
......
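With the `register_dataset`/`register_info` decorators above in place, construction goes through the prototype entry points rather than through `DatasetInfo`/`DatasetConfig`. A usage sketch, assuming the existing `torchvision.prototype.datasets.info`/`load` helpers and a manually prepared root (the ImageNet archives cannot be downloaded automatically, see `ImageNetResource` above; the root path is a placeholder):

```python
from torchvision.prototype import datasets

# category metadata now comes from the registered info() dict
categories = datasets.info("imagenet")["categories"]

# the returned dataset is itself an IterDataPipe yielding sample dicts
dataset = datasets.load("imagenet", root="/path/to/imagenet", split="val")
sample = next(iter(dataset))
print(sorted(sample.keys()))  # e.g. "image", "label", "path", "wnid", per _prepare_sample above
```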
......@@ -7,12 +7,13 @@ from typing import Any, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO, U
import torch
from torchdata.datapipes.iter import IterDataPipe, Demultiplexer, Mapper, Zipper, Decompressor
from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding, hint_shuffling
from torchvision.prototype.features import Image, Label
from torchvision.prototype.utils._internal import fromfile
__all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"]
from .._api import register_dataset, register_info
prod = functools.partial(functools.reduce, operator.mul)
......@@ -61,14 +62,14 @@ class _MNISTBase(Dataset):
_URL_BASE: Union[str, Sequence[str]]
@abc.abstractmethod
def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]:
def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]:
pass
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
def _resources(self) -> List[OnlineResource]:
(images_file, images_sha256), (
labels_file,
labels_sha256,
) = self._files_and_checksums(config)
) = self._files_and_checksums()
url_bases = self._URL_BASE
if isinstance(url_bases, str):
......@@ -82,21 +83,21 @@ class _MNISTBase(Dataset):
return [images, labels]
def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional[int]]:
def start_and_stop(self) -> Tuple[Optional[int], Optional[int]]:
return None, None
def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]:
_categories: List[str]
def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]:
image, label = data
return dict(
image=Image(image),
label=Label(label, dtype=torch.int64, categories=self.categories),
label=Label(label, dtype=torch.int64, categories=self._categories),
)
def _make_datapipe(
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
images_dp, labels_dp = resource_dps
start, stop = self.start_and_stop(config)
start, stop = self.start_and_stop()
images_dp = Decompressor(images_dp)
images_dp = MNISTFileReader(images_dp, start=start, stop=stop)
......@@ -107,19 +108,31 @@ class _MNISTBase(Dataset):
dp = Zipper(images_dp, labels_dp)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, functools.partial(self._prepare_sample, config=config))
return Mapper(dp, self._prepare_sample)
@register_info("mnist")
def _mnist_info() -> Dict[str, Any]:
return dict(
categories=[str(label) for label in range(10)],
)
@register_dataset("mnist")
class MNIST(_MNISTBase):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"mnist",
categories=10,
homepage="http://yann.lecun.com/exdb/mnist",
valid_options=dict(
split=("train", "test"),
),
)
"""
- **homepage**: http://yann.lecun.com/exdb/mnist
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "test"))
super().__init__(root, skip_integrity_check=skip_integrity_check)
_URL_BASE: Union[str, Sequence[str]] = (
"http://yann.lecun.com/exdb/mnist",
......@@ -132,8 +145,8 @@ class MNIST(_MNISTBase):
"t10k-labels-idx1-ubyte.gz": "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6",
}
def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]:
prefix = "train" if config.split == "train" else "t10k"
def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]:
prefix = "train" if self._split == "train" else "t10k"
images_file = f"{prefix}-images-idx3-ubyte.gz"
labels_file = f"{prefix}-labels-idx1-ubyte.gz"
return (images_file, self._CHECKSUMS[images_file]), (
......@@ -141,28 +154,35 @@ class MNIST(_MNISTBase):
self._CHECKSUMS[labels_file],
)
_categories = _mnist_info()["categories"]
def __len__(self) -> int:
return 60_000 if self._split == "train" else 10_000
@register_info("fashionmnist")
def _fashionmnist_info() -> Dict[str, Any]:
return dict(
categories=[
"T-shirt/top",
"Trouser",
"Pullover",
"Dress",
"Coat",
"Sandal",
"Shirt",
"Sneaker",
"Bag",
"Ankle boot",
],
)
@register_dataset("fashionmnist")
class FashionMNIST(MNIST):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"fashionmnist",
categories=(
"T-shirt/top",
"Trouser",
"Pullover",
"Dress",
"Coat",
"Sandal",
"Shirt",
"Sneaker",
"Bag",
"Ankle boot",
),
homepage="https://github.com/zalandoresearch/fashion-mnist",
valid_options=dict(
split=("train", "test"),
),
)
"""
- **homepage**: https://github.com/zalandoresearch/fashion-mnist
"""
_URL_BASE = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com"
_CHECKSUMS = {
......@@ -172,17 +192,21 @@ class FashionMNIST(MNIST):
"t10k-labels-idx1-ubyte.gz": "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5",
}
_categories = _fashionmnist_info()["categories"]
@register_info("kmnist")
def _kmnist_info() -> Dict[str, Any]:
return dict(
categories=["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"],
)
@register_dataset("kmnist")
class KMNIST(MNIST):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"kmnist",
categories=["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"],
homepage="http://codh.rois.ac.jp/kmnist/index.html.en",
valid_options=dict(
split=("train", "test"),
),
)
"""
- **homepage**: http://codh.rois.ac.jp/kmnist/index.html.en
"""
_URL_BASE = "http://codh.rois.ac.jp/kmnist/dataset/kmnist"
_CHECKSUMS = {
......@@ -192,36 +216,46 @@ class KMNIST(MNIST):
"t10k-labels-idx1-ubyte.gz": "20bb9a0ef54c7db3efc55a92eef5582c109615df22683c380526788f98e42a1c",
}
_categories = _kmnist_info()["categories"]
@register_info("emnist")
def _emnist_info() -> Dict[str, Any]:
return dict(
categories=list(string.digits + string.ascii_uppercase + string.ascii_lowercase),
)
@register_dataset("emnist")
class EMNIST(_MNISTBase):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"emnist",
categories=list(string.digits + string.ascii_uppercase + string.ascii_lowercase),
homepage="https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist",
valid_options=dict(
split=("train", "test"),
image_set=(
"Balanced",
"By_Merge",
"By_Class",
"Letters",
"Digits",
"MNIST",
),
),
"""
- **homepage**: https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
image_set: str = "Balanced",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "test"))
self._image_set = self._verify_str_arg(
image_set, "image_set", ("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST")
)
super().__init__(root, skip_integrity_check=skip_integrity_check)
_URL_BASE = "https://rds.westernsydney.edu.au/Institutes/MARCS/BENS/EMNIST"
def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]:
prefix = f"emnist-{config.image_set.replace('_', '').lower()}-{config.split}"
def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]:
prefix = f"emnist-{self._image_set.replace('_', '').lower()}-{self._split}"
images_file = f"{prefix}-images-idx3-ubyte.gz"
labels_file = f"{prefix}-labels-idx1-ubyte.gz"
# Since EMNIST provides the data files inside an archive, we don't need provide checksums for them
# Since EMNIST provides the data files inside an archive, we don't need to provide checksums for them
return (images_file, ""), (labels_file, "")
def resources(self, config: Optional[DatasetConfig] = None) -> List[OnlineResource]:
def _resources(self) -> List[OnlineResource]:
return [
HttpResource(
f"{self._URL_BASE}/emnist-gzip.zip",
......@@ -229,9 +263,9 @@ class EMNIST(_MNISTBase):
)
]
def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> Optional[int]:
def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
path = pathlib.Path(data[0])
(images_file, _), (labels_file, _) = self._files_and_checksums(config)
(images_file, _), (labels_file, _) = self._files_and_checksums()
if path.name == images_file:
return 0
elif path.name == labels_file:
......@@ -239,6 +273,8 @@ class EMNIST(_MNISTBase):
else:
return None
_categories = _emnist_info()["categories"]
_LABEL_OFFSETS = {
38: 1,
39: 1,
......@@ -251,45 +287,71 @@ class EMNIST(_MNISTBase):
46: 9,
}
def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]:
def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]:
# In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper).
# That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense,
# i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create this gaps. For example,
# since there is no 'c', 'd' corresponds to
# i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create these gaps. For
# example, since there is no 'c', 'd' corresponds to
# label 38 (10 digits + 26 uppercase letters + 3rd unmerged lower case letter - 1 for zero indexing),
# and at the same time corresponds to
# index 39 (10 digits + 26 uppercase letters + 4th lower case letter - 1 for zero indexing)
# in self.categories. Thus, we need to add 1 to the label to correct this.
if config.image_set in ("Balanced", "By_Merge"):
# in self._categories. Thus, we need to add 1 to the label to correct this.
if self._image_set in ("Balanced", "By_Merge"):
image, label = data
label += self._LABEL_OFFSETS.get(int(label), 0)
data = (image, label)
return super()._prepare_sample(data, config=config)
return super()._prepare_sample(data)
def _make_datapipe(
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
archive_dp = resource_dps[0]
images_dp, labels_dp = Demultiplexer(
archive_dp,
2,
functools.partial(self._classify_archive, config=config),
self._classify_archive,
drop_none=True,
buffer_size=INFINITE_BUFFER_SIZE,
)
return super()._make_datapipe([images_dp, labels_dp], config=config)
return super()._datapipe([images_dp, labels_dp])
def __len__(self) -> int:
return {
("train", "Balanced"): 112_800,
("train", "By_Merge"): 697_932,
("train", "By_Class"): 697_932,
("train", "Letters"): 124_800,
("train", "Digits"): 240_000,
("train", "MNIST"): 60_000,
("test", "Balanced"): 18_800,
("test", "By_Merge"): 116_323,
("test", "By_Class"): 116_323,
("test", "Letters"): 20_800,
("test", "Digits"): 40_000,
("test", "MNIST"): 10_000,
}[(self._split, self._image_set)]
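To make the label-offset logic in `EMNIST._prepare_sample` above concrete, a small worked sketch (only the offsets visible in the diff are reproduced; the full `_LABEL_OFFSETS` table is abridged):

```python
import string

categories = list(string.digits + string.ascii_uppercase + string.ascii_lowercase)
label_offsets = {38: 1, 39: 1, 46: 9}  # abridged copy of _LABEL_OFFSETS above

# In the "Balanced"/"By_Merge" image sets the raw label 38 denotes 'd' (the 3rd unmerged
# lowercase letter), but categories[38] is 'c'; the offset repairs the lookup:
raw_label = 38
assert categories[raw_label] == "c"
assert categories[raw_label + label_offsets.get(raw_label, 0)] == "d"
```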
@register_info("qmnist")
def _qmnist_info() -> Dict[str, Any]:
return dict(
categories=[str(label) for label in range(10)],
)
@register_dataset("qmnist")
class QMNIST(_MNISTBase):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"qmnist",
categories=10,
homepage="https://github.com/facebookresearch/qmnist",
valid_options=dict(
split=("train", "test", "test10k", "test50k", "nist"),
),
)
"""
- **homepage**: https://github.com/facebookresearch/qmnist
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "test", "test10k", "test50k", "nist"))
super().__init__(root, skip_integrity_check=skip_integrity_check)
_URL_BASE = "https://raw.githubusercontent.com/facebookresearch/qmnist/master"
_CHECKSUMS = {
......@@ -301,9 +363,9 @@ class QMNIST(_MNISTBase):
"xnist-labels-idx2-int.xz": "db042968723ec2b7aed5f1beac25d2b6e983b9286d4f4bf725f1086e5ae55c4f",
}
def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]:
prefix = "xnist" if config.split == "nist" else f"qmnist-{'train' if config.split== 'train' else 'test'}"
suffix = "xz" if config.split == "nist" else "gz"
def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]:
prefix = "xnist" if self._split == "nist" else f"qmnist-{'train' if self._split == 'train' else 'test'}"
suffix = "xz" if self._split == "nist" else "gz"
images_file = f"{prefix}-images-idx3-ubyte.{suffix}"
labels_file = f"{prefix}-labels-idx2-int.{suffix}"
return (images_file, self._CHECKSUMS[images_file]), (
......@@ -311,13 +373,13 @@ class QMNIST(_MNISTBase):
self._CHECKSUMS[labels_file],
)
def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional[int]]:
def start_and_stop(self) -> Tuple[Optional[int], Optional[int]]:
start: Optional[int]
stop: Optional[int]
if config.split == "test10k":
if self._split == "test10k":
start = 0
stop = 10000
elif config.split == "test50k":
elif self._split == "test50k":
start = 10000
stop = None
else:
......@@ -325,10 +387,12 @@ class QMNIST(_MNISTBase):
return start, stop
def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]:
_categories = _qmnist_info()["categories"]
def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]:
image, ann = data
label, *extra_anns = ann
sample = super()._prepare_sample((image, label), config=config)
sample = super()._prepare_sample((image, label))
sample.update(
dict(
......@@ -340,3 +404,12 @@ class QMNIST(_MNISTBase):
)
sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in extra_anns[-2:]])))
return sample
def __len__(self) -> int:
return {
"train": 60_000,
"test": 60_000,
"test10k": 10_000,
"test50k": 50_000,
"nist": 402_953,
}[self._split]
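The `start_and_stop` override above is what realises the virtual QMNIST test splits: they all read the same files and are merely sliced differently. A sketch of the resulting row ranges, consistent with `__len__` above:

```python
# split -> (start, stop) rows taken from the shared QMNIST test files
slices = {
    "test10k": (0, 10_000),     # first 10,000 rows
    "test50k": (10_000, None),  # everything after the first 10,000 rows
    "test": (None, None),       # all 60,000 rows
}
```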
import enum
import pathlib
from typing import Any, Dict, List, Optional, Tuple, BinaryIO
from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union
from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, CSVDictParser
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
)
......@@ -16,27 +14,41 @@ from torchvision.prototype.datasets.utils._internal import (
hint_shuffling,
getitem,
path_accessor,
read_categories_file,
path_comparator,
)
from torchvision.prototype.features import Label, EncodedImage
from .._api import register_dataset, register_info
class OxfordIITPetDemux(enum.IntEnum):
NAME = "oxford-iiit-pet"
class OxfordIIITPetDemux(enum.IntEnum):
SPLIT_AND_CLASSIFICATION = 0
SEGMENTATIONS = 1
class OxfordIITPet(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"oxford-iiit-pet",
homepage="https://www.robots.ox.ac.uk/~vgg/data/pets/",
valid_options=dict(
split=("trainval", "test"),
),
)
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=read_categories_file(NAME))
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
@register_dataset(NAME)
class OxfordIIITPet(Dataset):
"""Oxford IIIT Pet Dataset
homepage="https://www.robots.ox.ac.uk/~vgg/data/pets/",
"""
def __init__(
self, root: Union[str, pathlib.Path], *, split: str = "trainval", skip_integrity_check: bool = False
) -> None:
self._split = self._verify_str_arg(split, "split", {"trainval", "test"})
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
images = HttpResource(
"https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz",
sha256="67195c5e1c01f1ab5f9b6a5d22b8c27a580d896ece458917e61d459337fa318d",
......@@ -51,8 +63,8 @@ class OxfordIITPet(Dataset):
def _classify_anns(self, data: Tuple[str, Any]) -> Optional[int]:
return {
"annotations": OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION,
"trimaps": OxfordIITPetDemux.SEGMENTATIONS,
"annotations": OxfordIIITPetDemux.SPLIT_AND_CLASSIFICATION,
"trimaps": OxfordIIITPetDemux.SEGMENTATIONS,
}.get(pathlib.Path(data[0]).parent.name)
def _filter_images(self, data: Tuple[str, Any]) -> bool:
......@@ -70,7 +82,7 @@ class OxfordIITPet(Dataset):
image_path, image_buffer = image_data
return dict(
label=Label(int(classification_data["label"]) - 1, categories=self.categories),
label=Label(int(classification_data["label"]) - 1, categories=self._categories),
species="cat" if classification_data["species"] == "1" else "dog",
segmentation_path=segmentation_path,
segmentation=EncodedImage.from_file(segmentation_buffer),
......@@ -78,9 +90,7 @@ class OxfordIITPet(Dataset):
image=EncodedImage.from_file(image_buffer),
)
def _make_datapipe(
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
images_dp, anns_dp = resource_dps
images_dp = Filter(images_dp, self._filter_images)
......@@ -93,9 +103,7 @@ class OxfordIITPet(Dataset):
buffer_size=INFINITE_BUFFER_SIZE,
)
split_and_classification_dp = Filter(
split_and_classification_dp, path_comparator("name", f"{config.split}.txt")
)
split_and_classification_dp = Filter(split_and_classification_dp, path_comparator("name", f"{self._split}.txt"))
split_and_classification_dp = CSVDictParser(
split_and_classification_dp, fieldnames=("image_id", "label", "species"), delimiter=" "
)
......@@ -122,15 +130,14 @@ class OxfordIITPet(Dataset):
return Mapper(dp, self._prepare_sample)
def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool:
return self._classify_anns(data) == OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION
return self._classify_anns(data) == OxfordIIITPetDemux.SPLIT_AND_CLASSIFICATION
def _generate_categories(self, root: pathlib.Path) -> List[str]:
config = self.default_config
resources = self.resources(config)
def _generate_categories(self) -> List[str]:
resources = self._resources()
dp = resources[1].load(root)
dp = resources[1].load(self._root)
dp = Filter(dp, self._filter_split_and_classification_anns)
dp = Filter(dp, path_comparator("name", f"{config.split}.txt"))
dp = Filter(dp, path_comparator("name", "trainval.txt"))
dp = CSVDictParser(dp, fieldnames=("image_id", "label"), delimiter=" ")
raw_categories_and_labels = {(data["image_id"].rsplit("_", 1)[0], data["label"]) for data in dp}
......@@ -138,3 +145,6 @@ class OxfordIITPet(Dataset):
*sorted(raw_categories_and_labels, key=lambda raw_category_and_label: int(raw_category_and_label[1]))
)
return [" ".join(part.title() for part in raw_category.split("_")) for raw_category in raw_categories]
def __len__(self) -> int:
return 3_680 if self._split == "trainval" else 3_669
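`_generate_categories` above derives the display names from the image ids listed in `trainval.txt`; a worked sketch of that transformation (the image id is illustrative, not taken from the annotation file):

```python
image_id, label = "american_pit_bull_terrier_105", "2"
raw_category = image_id.rsplit("_", 1)[0]  # -> "american_pit_bull_terrier"
category = " ".join(part.title() for part in raw_category.split("_"))
assert category == "American Pit Bull Terrier"
```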
import io
import pathlib
from collections import namedtuple
from typing import Any, Dict, List, Optional, Tuple, Iterator
from typing import Any, Dict, List, Optional, Tuple, Iterator, Union
from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper
from torchvision.prototype import features
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
OnlineResource,
GDriveResource,
)
......@@ -17,6 +16,11 @@ from torchvision.prototype.datasets.utils._internal import (
)
from torchvision.prototype.features import Label
from .._api import register_dataset, register_info
NAME = "pcam"
class PCAMH5Reader(IterDataPipe[Tuple[str, io.IOBase]]):
def __init__(
......@@ -40,15 +44,25 @@ class PCAMH5Reader(IterDataPipe[Tuple[str, io.IOBase]]):
_Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256"))
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=["0", "1"])
@register_dataset(NAME)
class PCAM(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"pcam",
homepage="https://github.com/basveeling/pcam",
categories=2,
valid_options=dict(split=("train", "test", "val")),
dependencies=["h5py"],
)
# TODO write proper docstring
"""PCAM Dataset
homepage="https://github.com/basveeling/pcam"
"""
def __init__(
self, root: Union[str, pathlib.Path], split: str = "train", *, skip_integrity_check: bool = False
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "val", "test"})
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("h5py",))
_RESOURCES = {
"train": (
......@@ -89,10 +103,10 @@ class PCAM(Dataset):
),
}
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
def _resources(self) -> List[OnlineResource]:
return [ # = [images resource, targets resource]
GDriveResource(file_name=file_name, id=gdrive_id, sha256=sha256, preprocess="decompress")
for file_name, gdrive_id, sha256 in self._RESOURCES[config.split]
for file_name, gdrive_id, sha256 in self._RESOURCES[self._split]
]
def _prepare_sample(self, data: Tuple[Any, Any]) -> Dict[str, Any]:
......@@ -100,12 +114,10 @@ class PCAM(Dataset):
return {
"image": features.Image(image.transpose(2, 0, 1)),
"label": Label(target.item()),
"label": Label(target.item(), categories=self._categories),
}
def _make_datapipe(
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
images_dp, targets_dp = resource_dps
......@@ -116,3 +128,6 @@ class PCAM(Dataset):
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 262_144 if self._split == "train" else 32_768
import pathlib
import re
from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO
from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO, Union
import numpy as np
from torchdata.datapipes.iter import (
......@@ -11,13 +11,7 @@ from torchdata.datapipes.iter import (
IterKeyZipper,
LineReader,
)
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
)
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
read_mat,
......@@ -26,22 +20,42 @@ from torchvision.prototype.datasets.utils._internal import (
path_comparator,
hint_sharding,
hint_shuffling,
read_categories_file,
)
from torchvision.prototype.features import _Feature, EncodedImage
from .._api import register_dataset, register_info
NAME = "sbd"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=read_categories_file(NAME))
@register_dataset(NAME)
class SBD(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"sbd",
dependencies=("scipy",),
homepage="http://home.bharathh.info/pubs/codes/SBD/download.html",
valid_options=dict(
split=("train", "val", "train_noval"),
),
)
"""
- **homepage**: http://home.bharathh.info/pubs/codes/SBD/download.html
- **dependencies**:
- `scipy <https://scipy.org>`_
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "val", "train_noval"))
self._categories = _info()["categories"]
super().__init__(root, dependencies=("scipy",), skip_integrity_check=skip_integrity_check)
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
def _resources(self) -> List[OnlineResource]:
archive = HttpResource(
"https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz",
sha256="6a5a2918d5c73ce032fdeba876574d150d9d04113ab87540a1304cbcc715be53",
......@@ -85,12 +99,7 @@ class SBD(Dataset):
segmentation=_Feature(anns["Segmentation"].item()),
)
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
archive_dp, extra_split_dp = resource_dps
archive_dp = resource_dps[0]
......@@ -101,10 +110,10 @@ class SBD(Dataset):
buffer_size=INFINITE_BUFFER_SIZE,
drop_none=True,
)
if config.split == "train_noval":
if self._split == "train_noval":
split_dp = extra_split_dp
split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
split_dp = LineReader(split_dp, decode=True)
split_dp = hint_shuffling(split_dp)
split_dp = hint_sharding(split_dp)
......@@ -120,10 +129,17 @@ class SBD(Dataset):
)
return Mapper(dp, self._prepare_sample)
def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
resources = self.resources(self.default_config)
def __len__(self) -> int:
return {
"train": 8_498,
"val": 2_857,
"train_noval": 5_623,
}[self._split]
def _generate_categories(self) -> Tuple[str, ...]:
resources = self._resources()
dp = resources[0].load(root)
dp = resources[0].load(self._root)
dp = Filter(dp, path_comparator("name", "category_names.m"))
dp = LineReader(dp)
dp = Mapper(dp, bytes.decode, input_col=1)
......
from typing import Any, Dict, List, Tuple
import pathlib
from typing import Any, Dict, List, Tuple, Union
import torch
from torchdata.datapipes.iter import (
......@@ -8,24 +9,34 @@ from torchdata.datapipes.iter import (
)
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
)
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.features import Image, OneHotLabel
from .._api import register_dataset, register_info
NAME = "semeion"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=[str(i) for i in range(10)])
@register_dataset(NAME)
class SEMEION(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"semeion",
categories=10,
homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit",
)
"""Semeion dataset
homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit",
"""
def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None:
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
data = HttpResource(
"http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data",
sha256="f43228ae3da5ea6a3c95069d53450b86166770e3b719dcc333182128fe08d4b1",
......@@ -36,18 +47,16 @@ class SEMEION(Dataset):
image_data, label_data = data[:256], data[256:-1]
return dict(
image=Image(torch.tensor([float(pixel) for pixel in image_data], dtype=torch.uint8).reshape(16, 16)),
label=OneHotLabel([int(label) for label in label_data], categories=self.categories),
image=Image(torch.tensor([float(pixel) for pixel in image_data], dtype=torch.float).reshape(16, 16)),
label=OneHotLabel([int(label) for label in label_data], categories=self._categories),
)
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = CSVParser(dp, delimiter=" ")
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 1_593
import pathlib
from typing import Any, Dict, List, Tuple, Iterator, BinaryIO
from typing import Any, Dict, List, Tuple, Iterator, BinaryIO, Union
from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper
from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, path_comparator, read_mat
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
path_comparator,
read_mat,
read_categories_file,
)
from torchvision.prototype.features import BoundingBox, EncodedImage, Label
from .._api import register_dataset, register_info
class StanfordCarsLabelReader(IterDataPipe[Tuple[int, int, int, int, int, str]]):
def __init__(self, datapipe: IterDataPipe[Dict[str, Any]]) -> None:
......@@ -18,16 +26,31 @@ class StanfordCarsLabelReader(IterDataPipe[Tuple[int, int, int, int, int, str]])
yield tuple(ann) # type: ignore[misc]
NAME = "stanford-cars"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=read_categories_file(NAME))
@register_dataset(NAME)
class StanfordCars(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
name="stanford-cars",
homepage="https://ai.stanford.edu/~jkrause/cars/car_dataset.html",
dependencies=("scipy",),
valid_options=dict(
split=("test", "train"),
),
)
"""Stanford Cars dataset.
homepage="https://ai.stanford.edu/~jkrause/cars/car_dataset.html",
dependencies=scipy
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "test"})
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("scipy",))
_URL_ROOT = "https://ai.stanford.edu/~jkrause/"
_URLS = {
......@@ -44,9 +67,9 @@ class StanfordCars(Dataset):
"car_devkit": "512b227b30e2f0a8aab9e09485786ab4479582073a144998da74d64b801fd288",
}
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
resources: List[OnlineResource] = [HttpResource(self._URLS[config.split], sha256=self._CHECKSUM[config.split])]
if config.split == "train":
def _resources(self) -> List[OnlineResource]:
resources: List[OnlineResource] = [HttpResource(self._URLS[self._split], sha256=self._CHECKSUM[self._split])]
if self._split == "train":
resources.append(HttpResource(url=self._URLS["car_devkit"], sha256=self._CHECKSUM["car_devkit"]))
else:
......@@ -65,19 +88,14 @@ class StanfordCars(Dataset):
return dict(
path=path,
image=image,
label=Label(target[4] - 1, categories=self.categories),
label=Label(target[4] - 1, categories=self._categories),
bounding_box=BoundingBox(target[:4], format="xyxy", image_size=image.image_size),
)
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
images_dp, targets_dp = resource_dps
if config.split == "train":
if self._split == "train":
targets_dp = Filter(targets_dp, path_comparator("name", "cars_train_annos.mat"))
targets_dp = StanfordCarsLabelReader(targets_dp)
dp = Zipper(images_dp, targets_dp)
......@@ -85,12 +103,14 @@ class StanfordCars(Dataset):
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def _generate_categories(self, root: pathlib.Path) -> List[str]:
config = self.info.make_config(split="train")
resources = self.resources(config)
def _generate_categories(self) -> List[str]:
resources = self._resources()
devkit_dp = resources[1].load(root)
devkit_dp = resources[1].load(self._root)
meta_dp = Filter(devkit_dp, path_comparator("name", "cars_meta.mat"))
_, meta_file = next(iter(meta_dp))
return list(read_mat(meta_file, squeeze_me=True)["class_names"])
def __len__(self) -> int:
return 8_144 if self._split == "train" else 8_041
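For orientation, the annotation tuples yielded by `StanfordCarsLabelReader` map onto samples as sketched below (the row values are illustrative, not taken from the dataset):

```python
# (x1, y1, x2, y2, class, file name) as read from the annotation .mat files
target = (39, 116, 569, 375, 14, "00001.jpg")
bbox_xyxy = target[:4]       # consumed as BoundingBox(..., format="xyxy") above
label_index = target[4] - 1  # the .mat classes are 1-based, Label categories are 0-based
```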
from typing import Any, Dict, List, Tuple, BinaryIO
import pathlib
from typing import Any, Dict, List, Tuple, BinaryIO, Union
import numpy as np
from torchdata.datapipes.iter import (
......@@ -8,8 +9,6 @@ from torchdata.datapipes.iter import (
)
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
)
......@@ -20,16 +19,33 @@ from torchvision.prototype.datasets.utils._internal import (
)
from torchvision.prototype.features import Label, Image
from .._api import register_dataset, register_info
NAME = "svhn"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=[str(c) for c in range(10)])
@register_dataset(NAME)
class SVHN(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"svhn",
dependencies=("scipy",),
categories=10,
homepage="http://ufldl.stanford.edu/housenumbers/",
valid_options=dict(split=("train", "test", "extra")),
)
"""SVHN Dataset.
homepage="http://ufldl.stanford.edu/housenumbers/",
dependencies = scipy
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "test", "extra"})
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("scipy",))
_CHECKSUMS = {
"train": "435e94d69a87fde4fd4d7f3dd208dfc32cb6ae8af2240d066de1df7508d083b8",
......@@ -37,10 +53,10 @@ class SVHN(Dataset):
"extra": "a133a4beb38a00fcdda90c9489e0c04f900b660ce8a316a5e854838379a71eb3",
}
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
def _resources(self) -> List[OnlineResource]:
data = HttpResource(
f"http://ufldl.stanford.edu/housenumbers/{config.split}_32x32.mat",
sha256=self._CHECKSUMS[config.split],
f"http://ufldl.stanford.edu/housenumbers/{self._split}_32x32.mat",
sha256=self._CHECKSUMS[self._split],
)
return [data]
......@@ -60,18 +76,20 @@ class SVHN(Dataset):
return dict(
image=Image(image_array.transpose((2, 0, 1))),
label=Label(int(label_array) % 10, categories=self.categories),
label=Label(int(label_array) % 10, categories=self._categories),
)
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = Mapper(dp, self._read_images_and_labels)
dp = UnBatcher(dp)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return {
"train": 73_257,
"test": 26_032,
"extra": 531_131,
}[self._split]
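The `% 10` in `SVHN._prepare_sample` above exists because the SVHN `.mat` files label the digit 0 as class 10; a one-line sketch of the fold-back:

```python
assert [label % 10 for label in [1, 2, 9, 10]] == [1, 2, 9, 0]
```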
from typing import Any, Dict, List
import pathlib
from typing import Any, Dict, List, Union
import torch
from torchdata.datapipes.iter import IterDataPipe, LineReader, Mapper, Decompressor
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource, HttpResource
from torchvision.prototype.datasets.utils import Dataset, OnlineResource, HttpResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.features import Image, Label
from .._api import register_dataset, register_info
NAME = "usps"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=[str(c) for c in range(10)])
@register_dataset(NAME)
class USPS(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"usps",
homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps",
valid_options=dict(
split=("train", "test"),
),
categories=10,
)
"""USPS Dataset
homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps",
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "test"})
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
_URL = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass"
......@@ -29,8 +46,8 @@ class USPS(Dataset):
),
}
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
return [USPS._RESOURCES[config.split]]
def _resources(self) -> List[OnlineResource]:
return [USPS._RESOURCES[self._split]]
def _prepare_sample(self, line: str) -> Dict[str, Any]:
label, *values = line.strip().split(" ")
......@@ -38,17 +55,15 @@ class USPS(Dataset):
pixels = torch.tensor(values).add_(1).div_(2)
return dict(
image=Image(pixels.reshape(16, 16)),
label=Label(int(label) - 1, categories=self.categories),
label=Label(int(label) - 1, categories=self._categories),
)
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = Decompressor(resource_dps[0])
dp = LineReader(dp, decode=True, return_path=False)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 7_291 if self._split == "train" else 2_007
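The arithmetic in `USPS._prepare_sample` above rescales the pixel values (stored in [-1, 1] in the libsvm distribution) to [0, 1] and shifts the 1-based labels to 0-based category indices; a small sketch:

```python
value, label = -1.0, "10"
assert (value + 1) / 2 == 0.0  # pixels: [-1, 1] -> [0, 1]
assert int(label) - 1 == 9     # labels: 1..10  -> 0..9
```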
import enum
import functools
import pathlib
from typing import Any, Dict, List, Optional, Tuple, BinaryIO, cast, Callable
from typing import Any, Dict, List, Optional, Tuple, BinaryIO, cast, Union
from xml.etree import ElementTree
from torchdata.datapipes.iter import (
......@@ -12,13 +13,7 @@ from torchdata.datapipes.iter import (
LineReader,
)
from torchvision.datasets import VOCDetection
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
)
from torchvision.prototype.datasets.utils import OnlineResource, HttpResource, Dataset
from torchvision.prototype.datasets.utils._internal import (
path_accessor,
getitem,
......@@ -26,34 +21,48 @@ from torchvision.prototype.datasets.utils._internal import (
path_comparator,
hint_sharding,
hint_shuffling,
read_categories_file,
)
from torchvision.prototype.features import BoundingBox, Label, EncodedImage
from .._api import register_dataset, register_info
class VOCDatasetInfo(DatasetInfo):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self._configs = tuple(config for config in self._configs if config.split != "test" or config.year == "2007")
NAME = "voc"
def make_config(self, **options: Any) -> DatasetConfig:
config = super().make_config(**options)
if config.split == "test" and config.year != "2007":
raise ValueError("`split='test'` is only available for `year='2007'`")
return config
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=read_categories_file(NAME))
@register_dataset(NAME)
class VOC(Dataset):
def _make_info(self) -> DatasetInfo:
return VOCDatasetInfo(
"voc",
homepage="http://host.robots.ox.ac.uk/pascal/VOC/",
valid_options=dict(
split=("train", "val", "trainval", "test"),
year=("2012", "2007", "2008", "2009", "2010", "2011"),
task=("detection", "segmentation"),
),
)
"""
- **homepage**: http://host.robots.ox.ac.uk/pascal/VOC/
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
year: str = "2012",
task: str = "detection",
skip_integrity_check: bool = False,
) -> None:
self._year = self._verify_str_arg(year, "year", ("2007", "2008", "2009", "2010", "2011", "2012"))
if split == "test" and year != "2007":
raise ValueError("`split='test'` is only available for `year='2007'`")
else:
self._split = self._verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
self._task = self._verify_str_arg(task, "task", ("detection", "segmentation"))
self._anns_folder = "Annotations" if task == "detection" else "SegmentationClass"
self._split_folder = "Main" if task == "detection" else "Segmentation"
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
_TRAIN_VAL_ARCHIVES = {
"2007": ("VOCtrainval_06-Nov-2007.tar", "7d8cd951101b0957ddfd7a530bdc8a94f06121cfc1e511bb5937e973020c7508"),
......@@ -67,31 +76,27 @@ class VOC(Dataset):
"2007": ("VOCtest_06-Nov-2007.tar", "6836888e2e01dca84577a849d339fa4f73e1e4f135d312430c4856b5609b4892")
}
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
file_name, sha256 = (self._TEST_ARCHIVES if config.split == "test" else self._TRAIN_VAL_ARCHIVES)[config.year]
archive = HttpResource(f"http://host.robots.ox.ac.uk/pascal/VOC/voc{config.year}/{file_name}", sha256=sha256)
def _resources(self) -> List[OnlineResource]:
file_name, sha256 = (self._TEST_ARCHIVES if self._split == "test" else self._TRAIN_VAL_ARCHIVES)[self._year]
archive = HttpResource(f"http://host.robots.ox.ac.uk/pascal/VOC/voc{self._year}/{file_name}", sha256=sha256)
return [archive]
_ANNS_FOLDER = dict(
detection="Annotations",
segmentation="SegmentationClass",
)
_SPLIT_FOLDER = dict(
detection="Main",
segmentation="Segmentation",
)
def _is_in_folder(self, data: Tuple[str, Any], *, name: str, depth: int = 1) -> bool:
path = pathlib.Path(data[0])
return name in path.parent.parts[-depth:]
def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> Optional[int]:
class _Demux(enum.IntEnum):
SPLIT = 0
IMAGES = 1
ANNS = 2
def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
if self._is_in_folder(data, name="ImageSets", depth=2):
return 0
return self._Demux.SPLIT
elif self._is_in_folder(data, name="JPEGImages"):
return 1
elif self._is_in_folder(data, name=self._ANNS_FOLDER[config.task]):
return 2
return self._Demux.IMAGES
elif self._is_in_folder(data, name=self._anns_folder):
return self._Demux.ANNS
else:
return None
......@@ -111,7 +116,7 @@ class VOC(Dataset):
image_size=cast(Tuple[int, int], tuple(int(anns["size"][dim]) for dim in ("height", "width"))),
),
labels=Label(
[self.categories.index(instance["name"]) for instance in instances], categories=self.categories
[self._categories.index(instance["name"]) for instance in instances], categories=self._categories
),
)
......@@ -121,8 +126,6 @@ class VOC(Dataset):
def _prepare_sample(
self,
data: Tuple[Tuple[Tuple[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]],
*,
prepare_ann_fn: Callable[[BinaryIO], Dict[str, Any]],
) -> Dict[str, Any]:
split_and_image_data, ann_data = data
_, image_data = split_and_image_data
......@@ -130,29 +133,24 @@ class VOC(Dataset):
ann_path, ann_buffer = ann_data
return dict(
prepare_ann_fn(ann_buffer),
(self._prepare_detection_ann if self._task == "detection" else self._prepare_segmentation_ann)(ann_buffer),
image_path=image_path,
image=EncodedImage.from_file(image_buffer),
ann_path=ann_path,
)
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
archive_dp = resource_dps[0]
split_dp, images_dp, anns_dp = Demultiplexer(
archive_dp,
3,
functools.partial(self._classify_archive, config=config),
self._classify_archive,
drop_none=True,
buffer_size=INFINITE_BUFFER_SIZE,
)
split_dp = Filter(split_dp, functools.partial(self._is_in_folder, name=self._SPLIT_FOLDER[config.task]))
split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
split_dp = Filter(split_dp, functools.partial(self._is_in_folder, name=self._split_folder))
split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
split_dp = LineReader(split_dp, decode=True)
split_dp = hint_shuffling(split_dp)
split_dp = hint_sharding(split_dp)
@@ -166,25 +164,59 @@ class VOC(Dataset):
ref_key_fn=path_accessor("stem"),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(
dp,
functools.partial(
self._prepare_sample,
prepare_ann_fn=self._prepare_detection_ann
if config.task == "detection"
else self._prepare_segmentation_ann,
),
)
def _filter_detection_anns(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
return self._classify_archive(data, config=config) == 2
def _generate_categories(self, root: pathlib.Path) -> List[str]:
config = self.info.make_config(task="detection")
resource = self.resources(config)[0]
dp = resource.load(pathlib.Path(root) / self.name)
dp = Filter(dp, self._filter_detection_anns, fn_kwargs=dict(config=config))
return Mapper(dp, self._prepare_sample)
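`path_comparator` and `path_accessor` are small helpers over `pathlib.Path(data[0])`; a hedged sketch of how they behave on the (path, buffer) tuples flowing through the pipeline, assuming they are importable from the same `_internal` module as the other helpers in this diff:

```python
# Hedged sketch: path_comparator builds a predicate over a Path attribute of the
# first tuple element, path_accessor extracts that attribute (used as the join
# key for IterKeyZipper above).
from torchvision.prototype.datasets.utils._internal import path_accessor, path_comparator

is_train_txt = path_comparator("name", "train.txt")
print(is_train_txt(("VOC2012/ImageSets/Main/train.txt", None)))  # expected: True

stem_of = path_accessor("stem")
print(stem_of(("VOC2012/JPEGImages/2007_000027.jpg", None)))  # expected: "2007_000027"
```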
def __len__(self) -> int:
return {
("train", "2007", "detection"): 2_501,
("train", "2007", "segmentation"): 209,
("train", "2008", "detection"): 2_111,
("train", "2008", "segmentation"): 511,
("train", "2009", "detection"): 3_473,
("train", "2009", "segmentation"): 749,
("train", "2010", "detection"): 4_998,
("train", "2010", "segmentation"): 964,
("train", "2011", "detection"): 5_717,
("train", "2011", "segmentation"): 1_112,
("train", "2012", "detection"): 5_717,
("train", "2012", "segmentation"): 1_464,
("val", "2007", "detection"): 2_510,
("val", "2007", "segmentation"): 213,
("val", "2008", "detection"): 2_221,
("val", "2008", "segmentation"): 512,
("val", "2009", "detection"): 3_581,
("val", "2009", "segmentation"): 750,
("val", "2010", "detection"): 5_105,
("val", "2010", "segmentation"): 964,
("val", "2011", "detection"): 5_823,
("val", "2011", "segmentation"): 1_111,
("val", "2012", "detection"): 5_823,
("val", "2012", "segmentation"): 1_449,
("trainval", "2007", "detection"): 5_011,
("trainval", "2007", "segmentation"): 422,
("trainval", "2008", "detection"): 4_332,
("trainval", "2008", "segmentation"): 1_023,
("trainval", "2009", "detection"): 7_054,
("trainval", "2009", "segmentation"): 1_499,
("trainval", "2010", "detection"): 10_103,
("trainval", "2010", "segmentation"): 1_928,
("trainval", "2011", "detection"): 11_540,
("trainval", "2011", "segmentation"): 2_223,
("trainval", "2012", "detection"): 11_540,
("trainval", "2012", "segmentation"): 2_913,
("test", "2007", "detection"): 4_952,
("test", "2007", "segmentation"): 210,
}[(self._split, self._year, self._task)]
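With the config machinery gone, the end-user flow reduces to roughly the following hedged sketch (the keyword names mirror the constructor arguments, the sample keys are illustrative):

```python
# Hedged usage sketch: the dataset object is itself an IterDataPipe, so it can
# be sized and iterated directly.
from torchvision.prototype import datasets

dataset = datasets.load("voc", split="train", year="2012", task="detection")
print(len(dataset))  # 5717, per the table above

sample = next(iter(dataset))
print(sorted(sample))  # e.g. ['ann_path', 'bounding_boxes', 'image', 'image_path', 'labels']
```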
def _filter_anns(self, data: Tuple[str, Any]) -> bool:
return self._classify_archive(data) == self._Demux.ANNS
def _generate_categories(self) -> List[str]:
self._task = "detection"
resources = self._resources()
archive_dp = resources[0].load(self._root)
dp = Filter(archive_dp, self._filter_anns)
dp = Mapper(dp, self._parse_detection_ann, input_col=1)
return sorted({instance["name"] for _, anns in dp for instance in anns["object"]})
@@ -2,25 +2,21 @@
import argparse
import csv
import pathlib
import sys
from torchvision.prototype import datasets
from torchvision.prototype.datasets._api import find
from torchvision.prototype.datasets.utils._internal import BUILTIN_DIR
def main(*names, force=False):
home = pathlib.Path(datasets.home())
for name in names:
path = BUILTIN_DIR / f"{name}.categories"
if path.exists() and not force:
continue
dataset = find(name)
dataset = datasets.load(name)
try:
categories = dataset._generate_categories(home / name)
categories = dataset._generate_categories()
except NotImplementedError:
continue
......
from . import _internal # usort: skip
from ._dataset import DatasetConfig, DatasetInfo, Dataset
from ._dataset import Dataset
from ._query import SampleQuery
from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource
import abc
import csv
import importlib
import itertools
import os
import pathlib
from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection
from typing import Any, Dict, List, Optional, Sequence, Union, Collection, Iterator
from torch.utils.data import IterDataPipe
from torchvision._utils import sequence_to_str
from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion
from torchvision.datasets.utils import verify_str_arg
from .._home import use_sharded_dataset
from ._internal import BUILTIN_DIR, _make_sharded_datapipe
from ._resource import OnlineResource
class DatasetConfig(FrozenBunch):
# This needs to be Frozen because we often pass configs as partial(func, config=config)
# and partial() requires the parameters to be hashable.
pass
class Dataset(IterDataPipe[Dict[str, Any]], abc.ABC):
@staticmethod
def _verify_str_arg(
value: str,
arg: Optional[str] = None,
valid_values: Optional[Collection[str]] = None,
*,
custom_msg: Optional[str] = None,
) -> str:
return verify_str_arg(value, arg, valid_values, custom_msg=custom_msg)
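`verify_str_arg` is the same helper the stable datasets use; a brief hedged illustration of its behaviour:

```python
# Hedged illustration: the value is returned unchanged when it is one of the
# valid choices, otherwise a ValueError listing the choices is raised.
from torchvision.datasets.utils import verify_str_arg

split = verify_str_arg("train", "split", ("train", "val", "trainval", "test"))
assert split == "train"

try:
    verify_str_arg("all", "split", ("train", "val"))
except ValueError as error:
    print(error)  # names the argument and the accepted values
```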
class DatasetInfo:
def __init__(
self,
name: str,
*,
dependencies: Collection[str] = (),
categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None,
citation: Optional[str] = None,
homepage: Optional[str] = None,
license: Optional[str] = None,
valid_options: Optional[Dict[str, Sequence[Any]]] = None,
extra: Optional[Dict[str, Any]] = None,
self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False, dependencies: Collection[str] = ()
) -> None:
self.name = name.lower()
self.dependecies = dependencies
if categories is None:
path = BUILTIN_DIR / f"{self.name}.categories"
categories = path if path.exists() else []
if isinstance(categories, int):
categories = [str(label) for label in range(categories)]
elif isinstance(categories, (str, pathlib.Path)):
path = pathlib.Path(categories).expanduser().resolve()
categories, *_ = zip(*self.read_categories_file(path))
self.categories = tuple(categories)
self.citation = citation
self.homepage = homepage
self.license = license
self._valid_options = valid_options or dict()
self._configs = tuple(
DatasetConfig(**dict(zip(self._valid_options.keys(), combination)))
for combination in itertools.product(*self._valid_options.values())
)
self.extra = FrozenBunch(extra or dict())
@property
def default_config(self) -> DatasetConfig:
return self._configs[0]
@staticmethod
def read_categories_file(path: pathlib.Path) -> List[List[str]]:
with open(path, newline="") as file:
return [row for row in csv.reader(file)]
def make_config(self, **options: Any) -> DatasetConfig:
if not self._valid_options and options:
raise ValueError(
f"Dataset {self.name} does not take any options, "
f"but got {sequence_to_str(list(options), separate_last=' and')}."
)
for name, arg in options.items():
if name not in self._valid_options:
raise ValueError(
add_suggestion(
f"Unknown option '{name}' of dataset {self.name}.",
word=name,
possibilities=sorted(self._valid_options.keys()),
)
)
valid_args = self._valid_options[name]
if arg not in valid_args:
raise ValueError(
add_suggestion(
f"Invalid argument '{arg}' for option '{name}' of dataset {self.name}.",
word=arg,
possibilities=valid_args,
)
)
return DatasetConfig(self.default_config, **options)
def check_dependencies(self) -> None:
for dependency in self.dependecies:
for dependency in dependencies:
try:
importlib.import_module(dependency)
except ModuleNotFoundError as error:
except ModuleNotFoundError:
raise ModuleNotFoundError(
f"Dataset '{self.name}' depends on the third-party package '{dependency}'. "
f"{type(self).__name__}() depends on the third-party package '{dependency}'. "
f"Please install it, for example with `pip install {dependency}`."
) from error
def __repr__(self) -> str:
items = [("name", self.name)]
for key in ("citation", "homepage", "license"):
value = getattr(self, key)
if value is not None:
items.append((key, value))
items.extend(sorted((key, sequence_to_str(value)) for key, value in self._valid_options.items()))
return make_repr(type(self).__name__, items)
) from None
self._root = pathlib.Path(root).expanduser().resolve()
resources = [
resource.load(self._root, skip_integrity_check=skip_integrity_check) for resource in self._resources()
]
self._dp = self._datapipe(resources)
class Dataset(abc.ABC):
def __init__(self) -> None:
self._info = self._make_info()
def __iter__(self) -> Iterator[Dict[str, Any]]:
yield from self._dp
@abc.abstractmethod
def _make_info(self) -> DatasetInfo:
def _resources(self) -> List[OnlineResource]:
pass
@property
def info(self) -> DatasetInfo:
return self._info
@property
def name(self) -> str:
return self.info.name
@property
def default_config(self) -> DatasetConfig:
return self.info.default_config
@property
def categories(self) -> Tuple[str, ...]:
return self.info.categories
@abc.abstractmethod
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
pass
@abc.abstractmethod
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
def __len__(self) -> int:
pass
def supports_sharded(self) -> bool:
return False
def load(
self,
root: Union[str, pathlib.Path],
*,
config: Optional[DatasetConfig] = None,
skip_integrity_check: bool = False,
) -> IterDataPipe[Dict[str, Any]]:
if not config:
config = self.info.default_config
if use_sharded_dataset() and self.supports_sharded():
root = os.path.join(root, *config.values())
dataset_size = self.info.extra["sizes"][config]
return _make_sharded_datapipe(root, dataset_size) # type: ignore[no-any-return]
self.info.check_dependencies()
resource_dps = [
resource.load(root, skip_integrity_check=skip_integrity_check) for resource in self.resources(config)
]
return self._make_datapipe(resource_dps, config=config)
def _generate_categories(self, root: pathlib.Path) -> Sequence[Union[str, Sequence[str]]]:
def _generate_categories(self) -> Sequence[Union[str, Sequence[str]]]:
raise NotImplementedError
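Taken together, the contract a dataset now has to fulfil is small: supply `_resources`, `_datapipe`, and `__len__`, and let `__init__`/`__iter__` above handle root resolution, integrity checks, and iteration. A hedged sketch of a minimal subclass (name, URL, checksum, and sample layout are placeholders, not a real dataset):

```python
# Hedged sketch of a minimal dataset built on the refactored base class.
import pathlib
from typing import Any, BinaryIO, Dict, List, Tuple, Union

from torchdata.datapipes.iter import IterDataPipe, Mapper
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource


class Tiny(Dataset):
    def __init__(
        self,
        root: Union[str, pathlib.Path],
        *,
        split: str = "train",
        skip_integrity_check: bool = False,
    ) -> None:
        self._split = self._verify_str_arg(split, "split", ("train", "test"))
        # The base class resolves root, loads the resources, and builds the datapipe.
        super().__init__(root, skip_integrity_check=skip_integrity_check)

    def _resources(self) -> List[OnlineResource]:
        return [HttpResource("https://example.com/tiny.tar", sha256="<placeholder>")]

    def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
        path, _ = data
        return dict(path=path, split=self._split)

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        return Mapper(resource_dps[0], self._prepare_sample)

    def __len__(self) -> int:
        return 100 if self._split == "train" else 20
```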
import csv
import functools
import pathlib
import pickle
@@ -9,6 +10,7 @@ from typing import (
Any,
Tuple,
TypeVar,
List,
Iterator,
Dict,
IO,
@@ -198,3 +200,11 @@ def hint_sharding(datapipe: IterDataPipe) -> ShardingFilter:
def hint_shuffling(datapipe: IterDataPipe[D]) -> Shuffler[D]:
return Shuffler(datapipe, buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False)
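The `Shuffler` and `ShardingFilter` attached by these two hints stay inactive on their own; depending on the PyTorch/torchdata version, a consumer opts in later, e.g. via the DataLoader. A hedged sketch:

```python
# Hedged sketch: shuffling stays off (set_shuffle(False)) until a consumer
# enables it, e.g. DataLoader(shuffle=True) toggling every Shuffler in the graph.
from torch.utils.data import DataLoader
from torchvision.prototype import datasets

dp = datasets.load("voc", split="train", year="2012", task="detection")
loader = DataLoader(dp, batch_size=None, shuffle=True)
```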
def read_categories_file(name: str) -> List[Union[str, Sequence[str]]]:
path = BUILTIN_DIR / f"{name}.categories"
with open(path, newline="") as file:
rows = list(csv.reader(file))
rows = [row[0] if len(row) == 1 else row for row in rows]
return rows
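The `.categories` files this reads are plain CSV, one row per category: single-column rows collapse to strings, wider rows stay lists. A self-contained hedged sketch of the same rule on in-memory text (the values are illustrative):

```python
# Hedged sketch of the parsing rule applied to in-memory CSV text instead of a
# BUILTIN_DIR file.
import csv
import io


def parse_categories(text: str):
    rows = list(csv.reader(io.StringIO(text)))
    return [row[0] if len(row) == 1 else row for row in rows]


print(parse_categories("aeroplane\nbicycle\nbird\n"))
# ['aeroplane', 'bicycle', 'bird']
print(parse_categories("n01440764,tench\nn01443537,goldfish\n"))
# [['n01440764', 'tench'], ['n01443537', 'goldfish']]
```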
@@ -2,20 +2,13 @@ import collections.abc
import difflib
import io
import mmap
import os
import os.path
import platform
import textwrap
from typing import (
Any,
BinaryIO,
Callable,
cast,
Collection,
Iterable,
Iterator,
Mapping,
NoReturn,
Sequence,
Tuple,
TypeVar,
@@ -30,9 +23,6 @@ from torchvision._utils import sequence_to_str
__all__ = [
"add_suggestion",
"FrozenMapping",
"make_repr",
"FrozenBunch",
"fromfile",
"ReadOnlyTensorBuffer",
"apply_recursively",
@@ -60,82 +50,9 @@ def add_suggestion(
return f"{msg.strip()} {hint}"
K = TypeVar("K")
D = TypeVar("D")
class FrozenMapping(Mapping[K, D]):
def __init__(self, *args: Any, **kwargs: Any) -> None:
data = dict(*args, **kwargs)
self.__dict__["__data__"] = data
self.__dict__["__final_hash__"] = hash(tuple(data.items()))
def __getitem__(self, item: K) -> D:
return cast(Mapping[K, D], self.__dict__["__data__"])[item]
def __iter__(self) -> Iterator[K]:
return iter(self.__dict__["__data__"].keys())
def __len__(self) -> int:
return len(self.__dict__["__data__"])
def __immutable__(self) -> NoReturn:
raise RuntimeError(f"'{type(self).__name__}' object is immutable")
def __setitem__(self, key: K, value: Any) -> NoReturn:
self.__immutable__()
def __delitem__(self, key: K) -> NoReturn:
self.__immutable__()
def __hash__(self) -> int:
return cast(int, self.__dict__["__final_hash__"])
def __eq__(self, other: Any) -> bool:
if not isinstance(other, FrozenMapping):
return NotImplemented
return hash(self) == hash(other)
def __repr__(self) -> str:
return repr(self.__dict__["__data__"])
def make_repr(name: str, items: Iterable[Tuple[str, Any]]) -> str:
def to_str(sep: str) -> str:
return sep.join([f"{key}={value}" for key, value in items])
prefix = f"{name}("
postfix = ")"
body = to_str(", ")
line_length = int(os.environ.get("COLUMNS", 80))
body_too_long = (len(prefix) + len(body) + len(postfix)) > line_length
multiline_body = len(str(body).splitlines()) > 1
if not (body_too_long or multiline_body):
return prefix + body + postfix
body = textwrap.indent(to_str(",\n"), " " * 2)
return f"{prefix}\n{body}\n{postfix}"
class FrozenBunch(FrozenMapping):
def __getattr__(self, name: str) -> Any:
try:
return self[name]
except KeyError as error:
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") from error
def __setattr__(self, key: Any, value: Any) -> NoReturn:
self.__immutable__()
def __delattr__(self, item: Any) -> NoReturn:
self.__immutable__()
def __repr__(self) -> str:
return make_repr(type(self).__name__, self.items())
def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray:
# A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable
return bytearray(file.read(-1 if count == -1 else count * item_size))
......