"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "37c82480bbd7c97a9f2d9796eb368a54e666334d"
Unverified commit f630e671, authored by Philip Meier and committed by GitHub

add prototype for CIFAR datasets (#4511)

* add prototype for CIFAR datasets

Conflicts:
	torchvision/prototype/datasets/_builtin/__init__.py
	torchvision/prototype/datasets/utils/_internal.py

* fix mypy

* cleanup

* more cleanup

* revert unrelated changes

* fix code format

* avoid decoding twice by default

* revert unrelated change

* cleanup
parent a75dc89a
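As context for the decoder changes below ("avoid decoding twice by default"), here is a minimal usage sketch of the new behavior. It assumes load() and home() are re-exported from torchvision.prototype.datasets as in the rest of this prototype and that the CIFAR-10 archive is already available locally; the variable names are illustrative only.

import torchvision.prototype.datasets as datasets

# With no explicit decoder, load() now picks a default from DEFAULT_DECODER based
# on the dataset's DatasetType. CIFAR is DatasetType.RAW, so samples carry the
# stored pixels as uint8 tensors without an intermediate PNG encode/decode.
train_dp = datasets.load("cifar10", split="train")
sample = next(iter(train_dp))
print(sample["image"].shape, sample["image"].dtype)  # torch.Size([3, 32, 32]) torch.uint8
print(sample["label"], sample["category"])

# Passing decoder=None skips decoding entirely; for CIFAR each sample then holds
# an in-memory PNG buffer produced by image_buffer_from_array.
buffer_dp = datasets.load("cifar10", split="train", decoder=None)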
@@ -4,8 +4,8 @@ from typing import Any, Callable, Dict, List, Optional
 import torch
 from torch.utils.data import IterDataPipe
 from torchvision.prototype.datasets import home
-from torchvision.prototype.datasets.decoder import pil
-from torchvision.prototype.datasets.utils import Dataset, DatasetInfo
+from torchvision.prototype.datasets.decoder import raw, pil
+from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetType
 from torchvision.prototype.datasets.utils._internal import add_suggestion
 from . import _builtin

@@ -48,15 +48,26 @@ def info(name: str) -> DatasetInfo:
     return find(name).info

+default = object()
+
+DEFAULT_DECODER: Dict[DatasetType, Callable[[io.IOBase], torch.Tensor]] = {
+    DatasetType.RAW: raw,
+    DatasetType.IMAGE: pil,
+}

 def load(
     name: str,
     *,
-    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil,
+    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = default,  # type: ignore[assignment]
     split: str = "train",
     **options: Any,
 ) -> IterDataPipe[Dict[str, Any]]:
     dataset = find(name)
+    if decoder is default:
+        decoder = DEFAULT_DECODER.get(dataset.info.type)
+
     config = dataset.info.make_config(split=split, **options)
     root = home() / name
...
 from .caltech import Caltech101, Caltech256
+from .cifar import Cifar10, Cifar100
@@ -19,6 +19,7 @@ from torchvision.prototype.datasets.utils import (
     DatasetInfo,
     HttpResource,
     OnlineResource,
+    DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import create_categories_file, INFINITE_BUFFER_SIZE, read_mat

@@ -30,6 +31,7 @@ class Caltech101(Dataset):
     def info(self) -> DatasetInfo:
         return DatasetInfo(
             "caltech101",
+            type=DatasetType.IMAGE,
             categories=HERE / "caltech101.categories",
             homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101",
         )

@@ -146,6 +148,7 @@ class Caltech256(Dataset):
     def info(self) -> DatasetInfo:
         return DatasetInfo(
             "caltech256",
+            type=DatasetType.IMAGE,
             categories=HERE / "caltech256.categories",
             homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256",
         )
...
New file: cifar.py

import abc
import functools
import io
import pathlib
import pickle
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, TypeVar
import numpy as np
import torch
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.iter import (
Demultiplexer,
Filter,
Mapper,
TarArchiveReader,
Shuffler,
)
from torchdata.datapipes.iter import KeyZipper
from torchvision.prototype.datasets.decoder import raw
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
create_categories_file,
MappingIterator,
SequenceIterator,
INFINITE_BUFFER_SIZE,
image_buffer_from_array,
Enumerator,
)
__all__ = ["Cifar10", "Cifar100"]
HERE = pathlib.Path(__file__).parent
D = TypeVar("D")
class _CifarBase(Dataset):
@abc.abstractmethod
    def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> bool:
pass
@abc.abstractmethod
def _split_data_file(self, data: Tuple[str, Any]) -> Optional[int]:
pass
def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
_, file = data
return pickle.load(file, encoding="latin1")
def _remove_data_dict_key(self, data: Tuple[str, D]) -> D:
return data[1]
def _key_fn(self, data: Tuple[int, Any]) -> int:
return data[0]
def _collate_and_decode(
self,
data: Tuple[Tuple[int, int], Tuple[int, np.ndarray]],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
(_, category_idx), (_, image_array_flat) = data
category = self.categories[category_idx]
label = torch.tensor(category_idx)
image_array = image_array_flat.reshape((3, 32, 32))
image: Union[torch.Tensor, io.BytesIO]
if decoder is raw:
image = torch.from_numpy(image_array)
else:
image_buffer = image_buffer_from_array(image_array.transpose(1, 2, 0))
image = decoder(image_buffer) if decoder else image_buffer
return dict(label=label, category=category, image=image)
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
archive_dp = resource_dps[0]
archive_dp = TarArchiveReader(archive_dp)
archive_dp: IterDataPipe = Filter(archive_dp, functools.partial(self._is_data_file, config=config))
archive_dp: IterDataPipe = Mapper(archive_dp, self._unpickle)
archive_dp = MappingIterator(archive_dp)
images_dp, labels_dp = Demultiplexer(
archive_dp,
2,
self._split_data_file, # type: ignore[arg-type]
drop_none=True,
buffer_size=INFINITE_BUFFER_SIZE,
)
labels_dp: IterDataPipe = Mapper(labels_dp, self._remove_data_dict_key)
labels_dp: IterDataPipe = SequenceIterator(labels_dp)
labels_dp = Enumerator(labels_dp)
labels_dp = Shuffler(labels_dp, buffer_size=INFINITE_BUFFER_SIZE)
images_dp: IterDataPipe = Mapper(images_dp, self._remove_data_dict_key)
images_dp: IterDataPipe = SequenceIterator(images_dp)
images_dp = Enumerator(images_dp)
dp = KeyZipper(labels_dp, images_dp, self._key_fn, buffer_size=INFINITE_BUFFER_SIZE)
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))
@property
@abc.abstractmethod
def _meta_file_name(self) -> str:
pass
@property
@abc.abstractmethod
def _categories_key(self) -> str:
pass
def _is_meta_file(self, data: Tuple[str, Any]) -> bool:
path = pathlib.Path(data[0])
return path.name == self._meta_file_name
def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
dp = TarArchiveReader(dp)
dp: IterDataPipe = Filter(dp, self._is_meta_file)
dp: IterDataPipe = Mapper(dp, self._unpickle)
categories = next(iter(dp))[self._categories_key]
create_categories_file(HERE, self.name, categories)
class Cifar10(_CifarBase):
@property
def info(self) -> DatasetInfo:
return DatasetInfo(
"cifar10",
type=DatasetType.RAW,
categories=HERE / "cifar10.categories",
homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
)
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
return [
HttpResource(
"https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz",
sha256="6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce",
)
]
def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
path = pathlib.Path(data[0])
return path.name.startswith("data" if config.split == "train" else "test")
def _split_data_file(self, data: Tuple[str, Any]) -> Optional[int]:
key, _ = data
if key == "data":
return 0
elif key == "labels":
return 1
else:
return None
@property
def _meta_file_name(self) -> str:
return "batches.meta"
@property
def _categories_key(self) -> str:
return "label_names"
class Cifar100(_CifarBase):
@property
def info(self) -> DatasetInfo:
return DatasetInfo(
"cifar100",
type=DatasetType.RAW,
categories=HERE / "cifar100.categories",
homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
valid_options=dict(
split=("train", "test"),
),
)
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
return [
HttpResource(
"https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz",
sha256="85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7",
)
]
def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> bool:
path = pathlib.Path(data[0])
return path.name == config.split
def _split_data_file(self, data: Tuple[str, Any]) -> Optional[int]:
key, _ = data
if key == "data":
return 0
elif key == "fine_labels":
return 1
else:
return None
@property
def _meta_file_name(self) -> str:
return "meta"
@property
def _categories_key(self) -> str:
return "fine_label_names"
if __name__ == "__main__":
from torchvision.prototype.datasets import home
root = home()
Cifar10().generate_categories_file(root)
Cifar100().generate_categories_file(root)
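To make the two image paths in _collate_and_decode above concrete, this standalone sketch runs a synthetic CIFAR-style record through the same logic; it only assumes numpy, PIL, and torch are installed, and the random pixel data is made up for the example.

import io
import numpy as np
import PIL.Image
import torch

# Each CIFAR record stores an image as a flat uint8 vector of length 3 * 32 * 32,
# laid out channel-first, which is exactly what _collate_and_decode reshapes.
flat = np.random.randint(0, 256, size=3 * 32 * 32, dtype=np.uint8)
image_array = flat.reshape((3, 32, 32))

# Raw path (the DatasetType.RAW default): wrap the array without any re-encoding.
raw_image = torch.from_numpy(image_array)
assert raw_image.shape == (3, 32, 32) and raw_image.dtype == torch.uint8

# Buffer path (decoder=None): re-encode as PNG, mirroring image_buffer_from_array,
# which expects a channels-last array.
buffer = io.BytesIO()
PIL.Image.fromarray(image_array.transpose(1, 2, 0)).save(buffer, format="png")
buffer.seek(0)
assert PIL.Image.open(buffer).size == (32, 32)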

cifar10.categories:
airplane
automobile
bird
cat
deer
dog
frog
horse
ship
truck

cifar100.categories:
apple
aquarium_fish
baby
bear
beaver
bed
bee
beetle
bicycle
bottle
bowl
boy
bridge
bus
butterfly
camel
can
castle
caterpillar
cattle
chair
chimpanzee
clock
cloud
cockroach
couch
crab
crocodile
cup
dinosaur
dolphin
elephant
flatfish
forest
fox
girl
hamster
house
kangaroo
keyboard
lamp
lawn_mower
leopard
lion
lizard
lobster
man
maple_tree
motorcycle
mountain
mouse
mushroom
oak_tree
orange
orchid
otter
palm_tree
pear
pickup_truck
pine_tree
plain
plate
poppy
porcupine
possum
rabbit
raccoon
ray
road
rocket
rose
sea
seal
shark
shrew
skunk
skyscraper
snail
snake
spider
squirrel
streetcar
sunflower
sweet_pepper
table
tank
telephone
television
tiger
tractor
train
trout
tulip
turtle
wardrobe
whale
willow_tree
wolf
woman
worm
@@ -4,7 +4,11 @@ import PIL.Image
 import torch
 from torchvision.transforms.functional import pil_to_tensor

-__all__ = ["pil"]
+__all__ = ["raw", "pil"]

+def raw(buffer: io.IOBase) -> torch.Tensor:
+    raise RuntimeError("This is just a sentinel and should never be called.")

 def pil(buffer: io.IOBase, mode: str = "RGB") -> torch.Tensor:
...
 from . import _internal
-from ._dataset import DatasetConfig, DatasetInfo, Dataset
+from ._dataset import DatasetType, DatasetConfig, DatasetInfo, Dataset
 from ._resource import LocalResource, OnlineResource, HttpResource, GDriveResource
 import abc
+import enum
 import io
 import os
 import pathlib

@@ -45,6 +46,11 @@ def make_repr(name: str, items: Iterable[Tuple[str, Any]]):
     return f"{prefix}\n{body}\n{postfix}"

+class DatasetType(enum.Enum):
+    RAW = enum.auto()
+    IMAGE = enum.auto()

 class DatasetConfig(Mapping):
     def __init__(self, *args, **kwargs):
         data = dict(*args, **kwargs)

@@ -96,6 +102,7 @@ class DatasetInfo:
         self,
         name: str,
         *,
+        type: Union[str, DatasetType],
         categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None,
         citation: Optional[str] = None,
         homepage: Optional[str] = None,

@@ -103,6 +110,7 @@ class DatasetInfo:
         valid_options: Optional[Dict[str, Sequence]] = None,
     ) -> None:
         self.name = name.lower()
+        self.type = DatasetType[type.upper()] if isinstance(type, str) else type

         if categories is None:
             categories = []

@@ -111,7 +119,7 @@ class DatasetInfo:
         elif isinstance(categories, (str, pathlib.Path)):
             with open(pathlib.Path(categories).expanduser().resolve(), "r") as fh:
                 categories = [line.strip() for line in fh]
-        self.categories = categories
+        self.categories = tuple(categories)

         self.citation = citation
         self.homepage = homepage

@@ -181,6 +189,10 @@ class Dataset(abc.ABC):
     def default_config(self) -> DatasetConfig:
         return self.info.default_config

+    @property
+    def categories(self) -> Tuple[str, ...]:
+        return self.info.categories

     @abc.abstractmethod
     def resources(self, config: DatasetConfig) -> List[OnlineResource]:
         pass
...
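A small sketch of the DatasetInfo changes above, assuming the prototype utils package is importable and that these minimal constructor arguments are sufficient; the dataset name and category names are placeholders. The new required type argument may be passed as a DatasetType member or as a string, which is coerced via DatasetType[type.upper()], and categories are now normalized to a tuple.

from torchvision.prototype.datasets.utils import DatasetInfo, DatasetType

# Both spellings of the new `type` argument resolve to the same enum member.
info = DatasetInfo("dummy", type="raw", categories=["cat", "dog"])
assert info.type is DatasetType.RAW
assert info.categories == ("cat", "dog")  # stored as an immutable tuple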
@@ -2,10 +2,29 @@ import collections.abc
 import difflib
 import io
 import pathlib
-from typing import Collection, Sequence, Callable, Union, Any
+from typing import Collection, Sequence, Callable, Union, Iterator, Tuple, TypeVar, Dict, Any

+import numpy as np
+import PIL.Image

+from torch.utils.data import IterDataPipe

+__all__ = [
+    "INFINITE_BUFFER_SIZE",
+    "sequence_to_str",
+    "add_suggestion",
+    "create_categories_file",
+    "read_mat",
+    "image_buffer_from_array",
+    "SequenceIterator",
+    "MappingIterator",
+    "Enumerator",
+]

+K = TypeVar("K")
+D = TypeVar("D")

-__all__ = ["INFINITE_BUFFER_SIZE", "sequence_to_str", "add_suggestion", "create_categories_file", "read_mat"]

 # pseudo-infinite until a true infinite buffer is supported by all datapipes
 INFINITE_BUFFER_SIZE = 1_000_000_000

@@ -47,3 +66,39 @@ def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any:
         raise ModuleNotFoundError("Package `scipy` is required to be installed to read .mat files.") from error
     return sio.loadmat(buffer, **kwargs)

+def image_buffer_from_array(array: np.ndarray, *, format: str = "png") -> io.BytesIO:
+    image = PIL.Image.fromarray(array)
+    buffer = io.BytesIO()
+    image.save(buffer, format=format)
+    buffer.seek(0)
+    return buffer

+class SequenceIterator(IterDataPipe[D]):
+    def __init__(self, datapipe: IterDataPipe[Sequence[D]]):
+        self.datapipe = datapipe
+
+    def __iter__(self) -> Iterator[D]:
+        for sequence in self.datapipe:
+            yield from iter(sequence)

+class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]):
+    def __init__(self, datapipe: IterDataPipe[Dict[K, D]], *, drop_key: bool = False) -> None:
+        self.datapipe = datapipe
+        self.drop_key = drop_key
+
+    def __iter__(self) -> Iterator[Union[Tuple[K, D], D]]:
+        for mapping in self.datapipe:
+            yield from iter(mapping.values() if self.drop_key else mapping.items())  # type: ignore[call-overload]

+class Enumerator(IterDataPipe[Tuple[int, D]]):
+    def __init__(self, datapipe: IterDataPipe[D], start: int = 0) -> None:
+        self.datapipe = datapipe
+        self.start = start
+
+    def __iter__(self) -> Iterator[Tuple[int, D]]:
+        yield from enumerate(self.datapipe, self.start)
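As a plain-Python illustration of how the helpers added above are combined in the CIFAR datapipe (MappingIterator flattens the unpickled batch dict, SequenceIterator and Enumerator index the per-batch sequences, and KeyZipper joins labels to images on that index), the following sketch reproduces the same flow on a toy batch; it uses no datapipes and the batch contents are invented for the example.

# Toy stand-in for one unpickled CIFAR batch: "data" holds per-image arrays,
# "labels" the matching class indices (values are illustrative only).
batch = {"data": [[0, 1, 2], [3, 4, 5]], "labels": [7, 2]}

# MappingIterator: flatten the dict into (key, value) pairs.
pairs = list(batch.items())

# Demultiplexer with _split_data_file: route "data" to one stream, "labels" to the other.
image_seqs = [value for key, value in pairs if key == "data"]
label_seqs = [value for key, value in pairs if key == "labels"]

# SequenceIterator + Enumerator: unpack each sequence and attach a running index.
indexed_images = list(enumerate(item for seq in image_seqs for item in seq))
indexed_labels = list(enumerate(item for seq in label_seqs for item in seq))

# KeyZipper keyed on that index (_key_fn): pair label i with image i.
samples = [
    (label, image)
    for (i, label), (j, image) in zip(indexed_labels, indexed_images)
    if i == j
]
print(samples)  # [(7, [0, 1, 2]), (2, [3, 4, 5])]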