Unverified commit 140322f9, authored by Kushashwa Ravi Shrimali, committed by GitHub
Browse files

Port `semeion` dataset to `prototype` namespace (#4840)



* Port semeion dataset

* Update torchvision/prototype/datasets/_builtin/semeion.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* explicitly convert the image array to torch.uint8

* explicitly convert the image array to torch.uint8
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent e8ceaaf9
......@@ -5,4 +5,5 @@ from .coco import Coco
from .imagenet import ImageNet
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
from .sbd import SBD
from .semeion import SEMEION
from .voc import VOC
import io
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from torchdata.datapipes.iter import (
IterDataPipe,
Mapper,
Shuffler,
CSVParser,
)
from torchvision.prototype.datasets.decoder import raw
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, image_buffer_from_array
class SEMEION(Dataset):
    """Prototype dataset for the Semeion Handwritten Digit data.

    Each row of the space-delimited resource file becomes one sample: the
    first 256 fields are reshaped into a 16x16 image and the remaining fields
    are read as a one-hot encoded label.
    """

    def _make_info(self) -> DatasetInfo:
        """Return the static metadata describing this dataset."""
        return DatasetInfo(
            "semeion",
            type=DatasetType.RAW,
            categories=10,
            homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit",
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        """Return the single checksum-pinned data file the dataset is built from.

        Args:
            config: Dataset configuration; not used to select the resource.

        Returns:
            A one-element list holding the ``semeion.data`` HTTP resource.
        """
        archive = HttpResource(
            "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data",
            sha256="f43228ae3da5ea6a3c95069d53450b86166770e3b719dcc333182128fe08d4b1",
        )
        return [archive]

    def _collate_and_decode_sample(
        self,
        data: Tuple[str, ...],
        *,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        """Convert one parsed row into a sample dictionary.

        Args:
            data: The fields of one row. The first 256 fields are pixel
                values, everything after them is the one-hot label.
            decoder: ``raw`` to get the pixels as a tensor with an added
                leading dimension, another callable to decode an image buffer
                built from the pixels, or ``None`` to get the buffer itself.

        Returns:
            A dictionary with the keys ``image``, ``label`` and ``category``.

        Raises:
            ValueError: If none of the label fields is non-zero, i.e. the row
                does not encode any class.
        """
        pixels = data[:256]
        image_data = torch.tensor([float(pixel) for pixel in pixels], dtype=torch.uint8).reshape(16, 16)
        # Empty fields (e.g. one produced by a trailing delimiter) are skipped.
        label_data = [int(field) for field in data[256:] if field]

        if decoder is raw:
            image = image_data.unsqueeze(0)
        else:
            image_buffer = image_buffer_from_array(image_data.numpy())
            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]

        # Use an explicit default: a bare ``next()`` would leak ``StopIteration``
        # out of this mapping function, which surfaces as a confusing error (or
        # an early end of iteration) in the code that consumes the datapipe.
        label = next((idx for idx, one_hot_label in enumerate(label_data) if one_hot_label), None)
        if label is None:
            raise ValueError(f"No active entry found in the one-hot label fields {label_data}.")

        category = self.info.categories[label]
        return dict(image=image, label=label, category=category)

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        """Assemble the datapipe that yields the samples.

        Args:
            resource_dps: Datapipes of the loaded resources; only the first
                one is used.
            config: Dataset configuration; not used here.
            decoder: Forwarded to ``_collate_and_decode_sample``.

        Returns:
            A datapipe that parses the rows with a space delimiter, shuffles
            them and maps every row to a sample dictionary.
        """
        dp = resource_dps[0]
        dp = CSVParser(dp, delimiter=" ")
        dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
        dp = Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
        return dp
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment