Unverified Commit 80b0e35c authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

FER2013 dataset (#5120)



* add prototype dataset

* add old style dataset

* Apply suggestions from code review
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* refactor integrity check
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent 5c9c8352
......@@ -42,6 +42,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
EMNIST
FakeData
FashionMNIST
FER2013
Flickr8k
Flickr30k
FlyingChairs
......
import bz2
import contextlib
import csv
import io
import itertools
import json
......@@ -2241,5 +2242,38 @@ class DTDTestCase(datasets_utils.ImageDatasetTestCase):
return len(image_ids_in_config)
class FER2013TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.FER2013
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
FEATURE_TYPES = (PIL.Image.Image, (int, type(None)))
def inject_fake_data(self, tmpdir, config):
base_folder = os.path.join(tmpdir, "fer2013")
os.makedirs(base_folder)
num_samples = 5
with open(os.path.join(base_folder, f"{config['split']}.csv"), "w", newline="") as file:
writer = csv.DictWriter(
file,
fieldnames=("emotion", "pixels") if config["split"] == "train" else ("pixels",),
quoting=csv.QUOTE_NONNUMERIC,
quotechar='"',
)
writer.writeheader()
for _ in range(num_samples):
row = dict(
pixels=" ".join(
str(pixel) for pixel in datasets_utils.create_image_or_video_tensor((48, 48)).view(-1).tolist()
)
)
if config["split"] == "train":
row["emotion"] = str(int(torch.randint(0, 7, ())))
writer.writerow(row)
return num_samples
if __name__ == "__main__":
unittest.main()
......@@ -6,6 +6,7 @@ from .cityscapes import Cityscapes
from .coco import CocoCaptions, CocoDetection
from .dtd import DTD
from .fakedata import FakeData
from .fer2013 import FER2013
from .flickr import Flickr8k, Flickr30k
from .folder import ImageFolder, DatasetFolder
from .food101 import Food101
......@@ -81,4 +82,5 @@ __all__ = (
"HD1K",
"Food101",
"DTD",
"FER2013",
)
import csv
import pathlib
from typing import Any, Callable, Optional, Tuple
import torch
from PIL import Image
from .utils import verify_str_arg, check_integrity
from .vision import VisionDataset
class FER2013(VisionDataset):
"""`FER2013
<https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
``root/fer2013`` exists.
split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
"""
_RESOURCES = {
"train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"),
"test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"),
}
def __init__(
self,
root: str,
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
self._split = verify_str_arg(split, "split", self._RESOURCES.keys())
super().__init__(root, transform=transform, target_transform=target_transform)
base_folder = pathlib.Path(self.root) / "fer2013"
file_name, md5 = self._RESOURCES[self._split]
data_file = base_folder / file_name
if not check_integrity(str(data_file), md5=md5):
raise RuntimeError(
f"{file_name} not found in {base_folder} or corrupted. "
f"You can download it from "
f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
)
with open(data_file, "r", newline="") as file:
self._samples = [
(
torch.tensor([int(idx) for idx in row["pixels"].split()], dtype=torch.uint8).reshape(48, 48),
int(row["emotion"]) if "emotion" in row else None,
)
for row in csv.DictReader(file)
]
def __len__(self) -> int:
return len(self._samples)
def __getitem__(self, idx: int) -> Tuple[Any, Any]:
image_tensor, target = self._samples[idx]
image = Image.fromarray(image_tensor.numpy())
if self.transform is not None:
image = self.transform(image)
if self.target_transform is not None:
target = self.target_transform(target)
return image, target
def extra_repr(self) -> str:
return f"split={self._split}"
......@@ -3,6 +3,7 @@ from .celeba import CelebA
from .cifar import Cifar10, Cifar100
from .coco import Coco
from .dtd import DTD
from .fer2013 import FER2013
from .imagenet import ImageNet
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
from .sbd import SBD
......
import functools
import io
from typing import Any, Callable, Dict, List, Optional, Union, cast
import torch
from torchdata.datapipes.iter import IterDataPipe, Mapper, CSVDictParser
from torchvision.prototype.datasets.decoder import raw
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
OnlineResource,
DatasetType,
KaggleDownloadResource,
)
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
image_buffer_from_array,
)
from torchvision.prototype.features import Label, Image
class FER2013(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"fer2013",
type=DatasetType.RAW,
homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge",
categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"),
valid_options=dict(split=("train", "test")),
)
_CHECKSUMS = {
"train": "a2b7c9360cc0b38d21187e5eece01c2799fce5426cdeecf746889cc96cda2d10",
"test": "dec8dfe8021e30cd6704b85ec813042b4a5d99d81cb55e023291a94104f575c3",
}
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
archive = KaggleDownloadResource(
cast(str, self.info.homepage),
file_name=f"{config.split}.csv.zip",
sha256=self._CHECKSUMS[config.split],
)
return [archive]
def _collate_and_decode_sample(
self,
data: Dict[str, Any],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
raw_image = torch.tensor([int(idx) for idx in data["pixels"].split()], dtype=torch.uint8).reshape(48, 48)
label_id = data.get("emotion")
label_idx = int(label_id) if label_id is not None else None
image: Union[Image, io.BytesIO]
if decoder is raw:
image = Image(raw_image)
else:
image_buffer = image_buffer_from_array(raw_image.numpy())
image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment]
return dict(
image=image,
label=Label(label_idx, category=self.info.categories[label_idx]) if label_idx is not None else None,
)
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = CSVDictParser(dp)
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
from . import _internal
from ._dataset import DatasetType, DatasetConfig, DatasetInfo, Dataset
from ._query import SampleQuery
from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource
from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource
......@@ -176,3 +176,17 @@ class ManualDownloadResource(OnlineResource):
f"Please follow the instructions below and place it in {root}\n\n"
f"{self.instructions}"
)
class KaggleDownloadResource(ManualDownloadResource):
def __init__(self, challenge_url: str, *, file_name: str, **kwargs: Any) -> None:
instructions = "\n".join(
(
"1. Register and login at https://www.kaggle.com",
f"2. Navigate to {challenge_url}",
"3. Click 'Join Competition' and follow the instructions there",
"4. Navigate to the 'Data' tab",
f"5. Select {file_name} in the 'Data Explorer' and click the download button",
)
)
super().__init__(instructions, file_name=file_name, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment