Unverified Commit 2256b495 authored by Philip Meier, committed by GitHub

improve prototype CIFAR implementation (#4558)


Co-authored-by: Prabhat Roy <prabhatroy@fb.com>
parent a485b8c6
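The diff below replaces the old Demultiplexer/KeyZipper pipeline, which split each unpickled batch into separate image and label streams and re-joined them by enumeration index, with a single CifarFileReader datapipe that yields (image_array, label) pairs directly. It also moves the per-dataset values (labels key, meta file name, categories key) from abstract properties into class attributes. A minimal sketch of the data flow the new CifarFileReader implements, assuming a synthetic batch dict rather than the real archive:

import numpy as np

# Each CIFAR batch unpickles to a dict holding a flat uint8 "data" block of
# shape (N, 3072) plus a list of N label indices. The reader reshapes the
# block to (N, 3, 32, 32) and zips it with the labels.
batch = {
    "data": np.zeros((2, 3072), dtype=np.uint8),  # two fake 32x32 RGB images
    "labels": [3, 7],
}

image_arrays = batch["data"].reshape((-1, 3, 32, 32))
for image_array, category_idx in zip(image_arrays, batch["labels"]):
    print(image_array.shape, category_idx)  # (3, 32, 32) 3, then (3, 32, 32) 7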
@@ -3,19 +3,17 @@ import functools
 import io
 import pathlib
 import pickle
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union, TypeVar
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator
 
 import numpy as np
 import torch
 from torch.utils.data import IterDataPipe
 from torch.utils.data.datapipes.iter import (
-    Demultiplexer,
     Filter,
     Mapper,
     TarArchiveReader,
     Shuffler,
 )
-from torchdata.datapipes.iter import KeyZipper
 from torchvision.prototype.datasets.decoder import raw
 from torchvision.prototype.datasets.utils import (
     Dataset,
@@ -27,28 +25,35 @@ from torchvision.prototype.datasets.utils import (
 )
 from torchvision.prototype.datasets.utils._internal import (
     create_categories_file,
-    MappingIterator,
-    SequenceIterator,
     INFINITE_BUFFER_SIZE,
     image_buffer_from_array,
-    Enumerator,
-    getitem,
+    path_comparator,
 )
 
 __all__ = ["Cifar10", "Cifar100"]
 
 HERE = pathlib.Path(__file__).parent
 
-D = TypeVar("D")
+
+class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]):
+    def __init__(self, datapipe: IterDataPipe[Dict[str, Any]], *, labels_key: str) -> None:
+        self.datapipe = datapipe
+        self.labels_key = labels_key
+
+    def __iter__(self) -> Iterator[Tuple[np.ndarray, int]]:
+        for mapping in self.datapipe:
+            image_arrays = mapping["data"].reshape((-1, 3, 32, 32))
+            category_idcs = mapping[self.labels_key]
+            yield from iter(zip(image_arrays, category_idcs))
 
 
 class _CifarBase(Dataset):
-    @abc.abstractmethod
-    def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> Optional[int]:
-        pass
+    _LABELS_KEY: str
+    _META_FILE_NAME: str
+    _CATEGORIES_KEY: str
 
     @abc.abstractmethod
-    def _split_data_file(self, data: Tuple[str, Any]) -> Optional[int]:
+    def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> Optional[int]:
         pass
 
     def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
@@ -57,21 +62,20 @@ class _CifarBase(Dataset):
 
     def _collate_and_decode(
         self,
-        data: Tuple[Tuple[int, int], Tuple[int, np.ndarray]],
+        data: Tuple[np.ndarray, int],
         *,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> Dict[str, Any]:
-        (_, category_idx), (_, image_array_flat) = data
+        image_array, category_idx = data
 
         category = self.categories[category_idx]
         label = torch.tensor(category_idx)
 
-        image_array = image_array_flat.reshape((3, 32, 32))
         image: Union[torch.Tensor, io.BytesIO]
         if decoder is raw:
             image = torch.from_numpy(image_array)
         else:
-            image_buffer = image_buffer_from_array(image_array.transpose(1, 2, 0))
+            image_buffer = image_buffer_from_array(image_array.transpose((1, 2, 0)))
             image = decoder(image_buffer) if decoder else image_buffer
 
         return dict(label=label, category=category, image=image)
@@ -83,55 +87,32 @@ class _CifarBase(Dataset):
         config: DatasetConfig,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
-        archive_dp = resource_dps[0]
-        archive_dp = TarArchiveReader(archive_dp)
-        archive_dp: IterDataPipe = Filter(archive_dp, functools.partial(self._is_data_file, config=config))
-        archive_dp: IterDataPipe = Mapper(archive_dp, self._unpickle)
-        archive_dp = MappingIterator(archive_dp)
-        images_dp, labels_dp = Demultiplexer(
-            archive_dp,
-            2,
-            self._split_data_file,  # type: ignore[arg-type]
-            drop_none=True,
-            buffer_size=INFINITE_BUFFER_SIZE,
-        )
-
-        labels_dp: IterDataPipe = Mapper(labels_dp, getitem(1))
-        labels_dp: IterDataPipe = SequenceIterator(labels_dp)
-        labels_dp = Enumerator(labels_dp)
-        labels_dp = Shuffler(labels_dp, buffer_size=INFINITE_BUFFER_SIZE)
-
-        images_dp: IterDataPipe = Mapper(images_dp, getitem(1))
-        images_dp: IterDataPipe = SequenceIterator(images_dp)
-        images_dp = Enumerator(images_dp)
-
-        dp = KeyZipper(labels_dp, images_dp, getitem(0), buffer_size=INFINITE_BUFFER_SIZE)
+        dp = resource_dps[0]
+        dp: IterDataPipe = TarArchiveReader(dp)
+        dp: IterDataPipe = Filter(dp, functools.partial(self._is_data_file, config=config))
+        dp: IterDataPipe = Mapper(dp, self._unpickle)
+        dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
+        dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
         return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))
 
-    @property
-    @abc.abstractmethod
-    def _meta_file_name(self) -> str:
-        pass
-
-    @property
-    @abc.abstractmethod
-    def _categories_key(self) -> str:
-        pass
-
-    def _is_meta_file(self, data: Tuple[str, Any]) -> bool:
-        path = pathlib.Path(data[0])
-        return path.name == self._meta_file_name
-
     def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
         dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
         dp = TarArchiveReader(dp)
-        dp: IterDataPipe = Filter(dp, self._is_meta_file)
+        dp: IterDataPipe = Filter(dp, path_comparator("name", self._META_FILE_NAME))
         dp: IterDataPipe = Mapper(dp, self._unpickle)
-        categories = next(iter(dp))[self._categories_key]
+        categories = next(iter(dp))[self._CATEGORIES_KEY]
         create_categories_file(HERE, self.name, categories)
 
 
 class Cifar10(_CifarBase):
+    _LABELS_KEY = "labels"
+    _META_FILE_NAME = "batches.meta"
+    _CATEGORIES_KEY = "label_names"
+
+    def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
+        path = pathlib.Path(data[0])
+        return path.name.startswith("data" if config.split == "train" else "test")
+
     @property
     def info(self) -> DatasetInfo:
         return DatasetInfo(
@@ -149,29 +130,16 @@ class Cifar10(_CifarBase):
             )
         ]
 
-    def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
-        path = pathlib.Path(data[0])
-        return path.name.startswith("data" if config.split == "train" else "test")
-
-    def _split_data_file(self, data: Tuple[str, Any]) -> Optional[int]:
-        key, _ = data
-        if key == "data":
-            return 0
-        elif key == "labels":
-            return 1
-        else:
-            return None
-
-    @property
-    def _meta_file_name(self) -> str:
-        return "batches.meta"
-
-    @property
-    def _categories_key(self) -> str:
-        return "label_names"
-
 
 class Cifar100(_CifarBase):
+    _LABELS_KEY = "fine_labels"
+    _META_FILE_NAME = "meta"
+    _CATEGORIES_KEY = "fine_label_names"
+
+    def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> bool:
+        path = pathlib.Path(data[0])
+        return path.name == config.split
+
     @property
     def info(self) -> DatasetInfo:
         return DatasetInfo(
@@ -192,27 +160,6 @@ class Cifar100(_CifarBase):
             )
         ]
 
-    def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> bool:
-        path = pathlib.Path(data[0])
-        return path.name == config.split
-
-    def _split_data_file(self, data: Tuple[str, Any]) -> Optional[int]:
-        key, _ = data
-        if key == "data":
-            return 0
-        elif key == "fine_labels":
-            return 1
-        else:
-            return None
-
-    @property
-    def _meta_file_name(self) -> str:
-        return "meta"
-
-    @property
-    def _categories_key(self) -> str:
-        return "fine_label_names"
-
 
 if __name__ == "__main__":
     from torchvision.prototype.datasets import home
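The lookup pattern introduced by the new class attributes can be sketched standalone, without any of the datapipes machinery. In the sketch below, Base, Ten, and Hundred are hypothetical illustration names, not part of the commit:

import abc


class Base(abc.ABC):
    # Subclasses configure their pickle key via a plain class attribute
    # instead of overriding one abstract property per value.
    _LABELS_KEY: str

    def labels_of(self, batch: dict) -> list:
        return batch[self._LABELS_KEY]


class Ten(Base):
    _LABELS_KEY = "labels"


class Hundred(Base):
    _LABELS_KEY = "fine_labels"


print(Ten().labels_of({"labels": [3, 7]}))         # -> [3, 7]
print(Hundred().labels_of({"fine_labels": [42]}))  # -> [42]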