"docs/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "304efbb6509ff931b050fcb2f2895a8d98b1d220"
Unverified commit 2256b495, authored by Philip Meier, committed by GitHub

improve prototype CIFAR implementation (#4558)


Co-authored-by: Prabhat Roy <prabhatroy@fb.com>
parent a485b8c6
@@ -3,19 +3,17 @@ import functools
 import io
 import pathlib
 import pickle
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union, TypeVar
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator

 import numpy as np
 import torch
 from torch.utils.data import IterDataPipe
 from torch.utils.data.datapipes.iter import (
-    Demultiplexer,
     Filter,
     Mapper,
     TarArchiveReader,
     Shuffler,
 )
-from torchdata.datapipes.iter import KeyZipper
 from torchvision.prototype.datasets.decoder import raw
 from torchvision.prototype.datasets.utils import (
     Dataset,
@@ -27,28 +25,35 @@ from torchvision.prototype.datasets.utils import (
 )
 from torchvision.prototype.datasets.utils._internal import (
     create_categories_file,
-    MappingIterator,
-    SequenceIterator,
     INFINITE_BUFFER_SIZE,
     image_buffer_from_array,
-    Enumerator,
-    getitem,
+    path_comparator,
 )

 __all__ = ["Cifar10", "Cifar100"]

 HERE = pathlib.Path(__file__).parent

-D = TypeVar("D")
+
+class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]):
+    def __init__(self, datapipe: IterDataPipe[Dict[str, Any]], *, labels_key: str) -> None:
+        self.datapipe = datapipe
+        self.labels_key = labels_key
+
+    def __iter__(self) -> Iterator[Tuple[np.ndarray, int]]:
+        for mapping in self.datapipe:
+            image_arrays = mapping["data"].reshape((-1, 3, 32, 32))
+            category_idcs = mapping[self.labels_key]
+            yield from iter(zip(image_arrays, category_idcs))


 class _CifarBase(Dataset):
-    @abc.abstractmethod
-    def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> Optional[int]:
-        pass
+    _LABELS_KEY: str
+    _META_FILE_NAME: str
+    _CATEGORIES_KEY: str

     @abc.abstractmethod
-    def _split_data_file(self, data: Tuple[str, Any]) -> Optional[int]:
+    def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> Optional[int]:
         pass

     def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
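
The new CifarFileReader absorbs the per-file splitting that previously needed Demultiplexer, SequenceIterator, Enumerator, and KeyZipper: each unpickled CIFAR batch already stores a flat image buffer next to its label list, so a single datapipe can emit aligned (image, label) pairs directly. A minimal sketch of that reshape-and-zip step, using a hand-built batch dict in place of a real unpickled CIFAR file (the array contents here are dummy zeros; only the shapes and keys mirror the diff):

import numpy as np

# Dummy stand-in for one unpickled CIFAR batch: two images stored as flat
# rows of 3 * 32 * 32 uint8 values, plus their category indices.
batch = {
    "data": np.zeros((2, 3 * 32 * 32), dtype=np.uint8),
    "labels": [3, 7],
}

# The same logic as CifarFileReader.__iter__: reshape the whole buffer to
# NCHW once, then zip images with labels to get per-sample pairs.
image_arrays = batch["data"].reshape((-1, 3, 32, 32))
for image_array, category_idx in zip(image_arrays, batch["labels"]):
    assert image_array.shape == (3, 32, 32)  # one CHW image per sample
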
@@ -57,21 +62,20 @@ class _CifarBase(Dataset):
     def _collate_and_decode(
         self,
-        data: Tuple[Tuple[int, int], Tuple[int, np.ndarray]],
+        data: Tuple[np.ndarray, int],
         *,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> Dict[str, Any]:
-        (_, category_idx), (_, image_array_flat) = data
+        image_array, category_idx = data
         category = self.categories[category_idx]
         label = torch.tensor(category_idx)

-        image_array = image_array_flat.reshape((3, 32, 32))
-
         image: Union[torch.Tensor, io.BytesIO]
         if decoder is raw:
             image = torch.from_numpy(image_array)
         else:
-            image_buffer = image_buffer_from_array(image_array.transpose(1, 2, 0))
+            image_buffer = image_buffer_from_array(image_array.transpose((1, 2, 0)))
             image = decoder(image_buffer) if decoder else image_buffer

         return dict(label=label, category=category, image=image)
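
Because CifarFileReader now yields arrays already shaped (3, 32, 32), _collate_and_decode drops its per-sample reshape and keeps only the CHW-to-HWC transpose before encoding. The transpose in isolation, as a quick numpy check (independent of the torchvision helpers):

import numpy as np

chw = np.arange(3 * 32 * 32).reshape((3, 32, 32))
hwc = chw.transpose((1, 2, 0))  # CHW -> HWC, the layout the image encoder expects
assert hwc.shape == (32, 32, 3)
assert hwc[0, 0, 1] == chw[1, 0, 0]  # channel axis moved to the end
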
@@ -83,55 +87,32 @@ class _CifarBase(Dataset):
         config: DatasetConfig,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
-        archive_dp = resource_dps[0]
-        archive_dp = TarArchiveReader(archive_dp)
-        archive_dp: IterDataPipe = Filter(archive_dp, functools.partial(self._is_data_file, config=config))
-        archive_dp: IterDataPipe = Mapper(archive_dp, self._unpickle)
-        archive_dp = MappingIterator(archive_dp)
-        images_dp, labels_dp = Demultiplexer(
-            archive_dp,
-            2,
-            self._split_data_file,  # type: ignore[arg-type]
-            drop_none=True,
-            buffer_size=INFINITE_BUFFER_SIZE,
-        )
-
-        labels_dp: IterDataPipe = Mapper(labels_dp, getitem(1))
-        labels_dp: IterDataPipe = SequenceIterator(labels_dp)
-        labels_dp = Enumerator(labels_dp)
-        labels_dp = Shuffler(labels_dp, buffer_size=INFINITE_BUFFER_SIZE)
-
-        images_dp: IterDataPipe = Mapper(images_dp, getitem(1))
-        images_dp: IterDataPipe = SequenceIterator(images_dp)
-        images_dp = Enumerator(images_dp)
-
-        dp = KeyZipper(labels_dp, images_dp, getitem(0), buffer_size=INFINITE_BUFFER_SIZE)
+        dp = resource_dps[0]
+        dp: IterDataPipe = TarArchiveReader(dp)
+        dp: IterDataPipe = Filter(dp, functools.partial(self._is_data_file, config=config))
+        dp: IterDataPipe = Mapper(dp, self._unpickle)
+        dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
+        dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
         return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))

-    @property
-    @abc.abstractmethod
-    def _meta_file_name(self) -> str:
-        pass
-
-    @property
-    @abc.abstractmethod
-    def _categories_key(self) -> str:
-        pass
-
-    def _is_meta_file(self, data: Tuple[str, Any]) -> bool:
-        path = pathlib.Path(data[0])
-        return path.name == self._meta_file_name
-
     def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
         dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
         dp = TarArchiveReader(dp)
-        dp: IterDataPipe = Filter(dp, self._is_meta_file)
+        dp: IterDataPipe = Filter(dp, path_comparator("name", self._META_FILE_NAME))
         dp: IterDataPipe = Mapper(dp, self._unpickle)
-        categories = next(iter(dp))[self._categories_key]
+        categories = next(iter(dp))[self._CATEGORIES_KEY]
         create_categories_file(HERE, self.name, categories)


 class Cifar10(_CifarBase):
+    _LABELS_KEY = "labels"
+    _META_FILE_NAME = "batches.meta"
+    _CATEGORIES_KEY = "label_names"
+
+    def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
+        path = pathlib.Path(data[0])
+        return path.name.startswith("data" if config.split == "train" else "test")
+
     @property
     def info(self) -> DatasetInfo:
         return DatasetInfo(
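
With the demux/zip graph gone, _datapipe is one straight chain: tar archive, filter data files, unpickle, CifarFileReader, shuffle, collate/decode. A linear chain also removes the buffering KeyZipper needed to re-align the two branches. A rough sketch of the repeated dp = Stage(dp) composition pattern, with plain generator functions standing in for the real datapipes (the names and stages below are illustrative, not torchvision API):

from functools import reduce

def chain(*stages):
    # Compose stages left to right, mirroring the repeated dp = Stage(dp) pattern.
    return lambda source: reduce(lambda acc, stage: stage(acc), stages, source)

# Illustrative stages: each consumes and yields an iterable, like an IterDataPipe.
unpickle = lambda files: ({"data": f, "labels": [0]} for f in files)
read = lambda batches: ((b["data"], b["labels"][0]) for b in batches)

samples = chain(unpickle, read)(["file1", "file2"])
assert list(samples) == [("file1", 0), ("file2", 0)]
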
@@ -149,29 +130,16 @@ class Cifar10(_CifarBase):
             )
         ]

-    def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
-        path = pathlib.Path(data[0])
-        return path.name.startswith("data" if config.split == "train" else "test")
-
-    def _split_data_file(self, data: Tuple[str, Any]) -> Optional[int]:
-        key, _ = data
-        if key == "data":
-            return 0
-        elif key == "labels":
-            return 1
-        else:
-            return None
-
-    @property
-    def _meta_file_name(self) -> str:
-        return "batches.meta"
-
-    @property
-    def _categories_key(self) -> str:
-        return "label_names"
-

 class Cifar100(_CifarBase):
+    _LABELS_KEY = "fine_labels"
+    _META_FILE_NAME = "meta"
+    _CATEGORIES_KEY = "fine_label_names"
+
+    def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> bool:
+        path = pathlib.Path(data[0])
+        return path.name == config.split
+
     @property
     def info(self) -> DatasetInfo:
         return DatasetInfo(
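
The per-dataset customization is now declarative: _CifarBase annotates _LABELS_KEY, _META_FILE_NAME, and _CATEGORIES_KEY, and each subclass assigns concrete values, replacing the abstract _meta_file_name/_categories_key properties and their per-class overrides. The pattern in isolation (hypothetical minimal classes, not the torchvision ones):

import abc

class Base(abc.ABC):
    _LABELS_KEY: str  # declared here, assigned by each concrete subclass

    def labels_of(self, batch: dict) -> list:
        # Shared logic parameterized purely by the class attribute.
        return batch[self._LABELS_KEY]

class Ten(Base):
    _LABELS_KEY = "labels"

class Hundred(Base):
    _LABELS_KEY = "fine_labels"

assert Hundred().labels_of({"fine_labels": [1, 2]}) == [1, 2]
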
@@ -192,27 +160,6 @@ class Cifar100(_CifarBase):
             )
         ]

-    def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> bool:
-        path = pathlib.Path(data[0])
-        return path.name == config.split
-
-    def _split_data_file(self, data: Tuple[str, Any]) -> Optional[int]:
-        key, _ = data
-        if key == "data":
-            return 0
-        elif key == "fine_labels":
-            return 1
-        else:
-            return None
-
-    @property
-    def _meta_file_name(self) -> str:
-        return "meta"
-
-    @property
-    def _categories_key(self) -> str:
-        return "fine_label_names"
-

 if __name__ == "__main__":
     from torchvision.prototype.datasets import home