Unverified Commit 5cc477c7 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

use fn helpers in CIFAR prototype (#4543)

parent 4b6fc6b8
......@@ -32,6 +32,7 @@ from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
image_buffer_from_array,
Enumerator,
getitem,
)
__all__ = ["Cifar10", "Cifar100"]
......@@ -54,12 +55,6 @@ class _CifarBase(Dataset):
_, file = data
return pickle.load(file, encoding="latin1")
def _remove_data_dict_key(self, data: Tuple[str, D]) -> D:
return data[1]
def _key_fn(self, data: Tuple[int, Any]) -> int:
return data[0]
def _collate_and_decode(
self,
data: Tuple[Tuple[int, int], Tuple[int, np.ndarray]],
......@@ -101,16 +96,16 @@ class _CifarBase(Dataset):
buffer_size=INFINITE_BUFFER_SIZE,
)
labels_dp: IterDataPipe = Mapper(labels_dp, self._remove_data_dict_key)
labels_dp: IterDataPipe = Mapper(labels_dp, getitem(1))
labels_dp: IterDataPipe = SequenceIterator(labels_dp)
labels_dp = Enumerator(labels_dp)
labels_dp = Shuffler(labels_dp, buffer_size=INFINITE_BUFFER_SIZE)
images_dp: IterDataPipe = Mapper(images_dp, self._remove_data_dict_key)
images_dp: IterDataPipe = Mapper(images_dp, getitem(1))
images_dp: IterDataPipe = SequenceIterator(images_dp)
images_dp = Enumerator(images_dp)
dp = KeyZipper(labels_dp, images_dp, self._key_fn, buffer_size=INFINITE_BUFFER_SIZE)
dp = KeyZipper(labels_dp, images_dp, getitem(0), buffer_size=INFINITE_BUFFER_SIZE)
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))
@property
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment