Unverified Commit 8cc6d521 authored by Philip Meier, committed by GitHub

add prototype dataset for CelebA (#4514)

* add prototype dataset for CelebA

* fix code format

* fix mypy

* hardcode fmtparams

* fix mypy

* replace KeyZipper with Zipper for annotations
parent 493d301e
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import Cifar10, Cifar100
from .sbd import SBD
from .voc import VOC
import csv
import io
from typing import Any, Callable, Dict, List, Optional, Tuple, Mapping, Union

import torch
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.iter import (
    Mapper,
    Shuffler,
    Filter,
    ZipArchiveReader,
    Zipper,
)
from torchdata.datapipes.iter import KeyZipper
from torchvision.prototype.datasets.utils import (
    Dataset,
    DatasetConfig,
    DatasetInfo,
    GDriveResource,
    OnlineResource,
    DatasetType,
)
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, getitem, path_accessor
class CelebACSVParser(IterDataPipe):
    def __init__(
        self,
        datapipe,
        *,
        has_header,
    ):
        self.datapipe = datapipe
        self.has_header = has_header
        self._fmtparams = dict(delimiter=" ", skipinitialspace=True)

    def __iter__(self):
        for _, file in self.datapipe:
            file = (line.decode() for line in file)

            if self.has_header:
                # The first row is skipped, because it only contains the number of samples
                next(file)

                # Empty field names are filtered out, because some files have an extra white space after
                # the header line, which is recognized as an extra column
                fieldnames = [name for name in next(csv.reader([next(file)], **self._fmtparams)) if name]
                # Some files do not include a label for the image ID column
                if fieldnames[0] != "image_id":
                    fieldnames.insert(0, "image_id")

                for line in csv.DictReader(file, fieldnames=fieldnames, **self._fmtparams):
                    yield line.pop("image_id"), line
            else:
                for line in csv.reader(file, **self._fmtparams):
                    yield line[0], line[1:]
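# A hedged sketch of what the parser yields in each mode (file excerpts follow the
# published CelebA annotation layout; the concrete values are illustrative and not
# re-verified here):
#
#   list_eval_partition.txt (has_header=False), e.g. the line "000001.jpg 0",
#   yields ("000001.jpg", ["0"]).
#
#   list_bbox_celeba.txt (has_header=True) starts with the sample count and a header:
#       202599
#       image_id x_1 y_1 width height
#       000001.jpg 95 71 226 313
#   and yields ("000001.jpg", {"x_1": "95", "y_1": "71", "width": "226", "height": "313"}).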
class CelebA(Dataset):
    @property
    def info(self) -> DatasetInfo:
        return DatasetInfo(
            "celeba",
            type=DatasetType.IMAGE,
            homepage="https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html",
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        splits = GDriveResource(
            "0B7EVK8r0v71pY0NSMzRuSXJEVkk",
            sha256="fc955bcb3ef8fbdf7d5640d9a8693a8431b5f2ee291a5c1449a1549e7e073fe7",
            file_name="list_eval_partition.txt",
        )
        images = GDriveResource(
            "0B7EVK8r0v71pZjFTYXZWM3FlRnM",
            sha256="46fb89443c578308acf364d7d379fe1b9efb793042c0af734b6112e4fd3a8c74",
            file_name="img_align_celeba.zip",
        )
        identities = GDriveResource(
            "1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS",
            sha256="c6143857c3e2630ac2da9f782e9c1232e5e59be993a9d44e8a7916c78a6158c0",
            file_name="identity_CelebA.txt",
        )
        attributes = GDriveResource(
            "0B7EVK8r0v71pblRyaVFSWGxPY0U",
            sha256="f0e5da289d5ccf75ffe8811132694922b60f2af59256ed362afa03fefba324d0",
            file_name="list_attr_celeba.txt",
        )
        bboxes = GDriveResource(
            "0B7EVK8r0v71pbThiMVRxWXZ4dU0",
            sha256="7487a82e57c4bb956c5445ae2df4a91ffa717e903c5fa22874ede0820c8ec41b",
            file_name="list_bbox_celeba.txt",
        )
        landmarks = GDriveResource(
            "0B7EVK8r0v71pd0FJY3Blby1HUTQ",
            sha256="6c02a87569907f6db2ba99019085697596730e8129f67a3d61659f198c48d43b",
            file_name="list_landmarks_align_celeba.txt",
        )
        return [splits, images, identities, attributes, bboxes, landmarks]
    _SPLIT_ID_TO_NAME = {
        "0": "train",
        "1": "valid",
        "2": "test",
    }

    def _filter_split(self, data: Tuple[str, str], *, split):
        _, split_id = data
        return self._SPLIT_ID_TO_NAME[split_id[0]] == split
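    # Example (assuming the partition sketch above): for ("000001.jpg", ["0"]),
    # split_id is ["0"], so split_id[0] maps to "train" and the entry is kept
    # only when config.split == "train".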
    def _collate_anns(
        self, data: Tuple[Tuple[str, Union[List[str], Mapping[str, str]]], ...]
    ) -> Tuple[str, Dict[str, Union[List[str], Mapping[str, str]]]]:
        (image_id, identity), (_, attributes), (_, bbox), (_, landmarks) = data
        return image_id, dict(identity=identity, attributes=attributes, bbox=bbox, landmarks=landmarks)
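    # A hedged sketch of the collation (values illustrative): the four parser outputs
    # for one image arrive positionally as
    #   (("000001.jpg", ["2880"]), ("000001.jpg", {attrs}), ("000001.jpg", {bbox}), ("000001.jpg", {marks}))
    # and are merged into
    #   ("000001.jpg", {"identity": ["2880"], "attributes": {...}, "bbox": {...}, "landmarks": {...}}).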
    def _collate_and_decode_sample(
        self,
        data: Tuple[Tuple[str, Tuple[str, List[str]], Tuple[str, io.IOBase]], Tuple[str, Dict[str, Any]]],
        *,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        split_and_image_data, ann_data = data
        _, _, image_data = split_and_image_data
        path, buffer = image_data
        _, ann = ann_data

        image = decoder(buffer) if decoder else buffer

        identity = torch.tensor(int(ann["identity"][0]))
        attributes = {attr: value == "1" for attr, value in ann["attributes"].items()}
        bbox = torch.tensor([int(ann["bbox"][key]) for key in ("x_1", "y_1", "width", "height")])
        landmarks = {
            landmark: torch.tensor((int(ann["landmarks"][f"{landmark}_x"]), int(ann["landmarks"][f"{landmark}_y"])))
            for landmark in {key[:-2] for key in ann["landmarks"].keys()}
        }

        return dict(
            path=path,
            image=image,
            identity=identity,
            attributes=attributes,
            bbox=bbox,
            landmarks=landmarks,
        )
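    # A hedged sketch of one finished sample (keys taken from the dict above, tensor
    # values illustrative only):
    #   {
    #       "path": ".../img_align_celeba/000001.jpg",
    #       "image": <decoded tensor, or the raw buffer if decoder is None>,
    #       "identity": tensor(2880),
    #       "attributes": {"5_o_Clock_Shadow": False, "Arched_Eyebrows": True, ...},
    #       "bbox": tensor([95, 71, 226, 313]),
    #       "landmarks": {"lefteye": tensor([69, 109]), "nose": tensor([...]), ...},
    #   }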
    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        splits_dp, images_dp, identities_dp, attributes_dp, bboxes_dp, landmarks_dp = resource_dps

        # Keep only the entries of the requested split and shuffle them.
        splits_dp = CelebACSVParser(splits_dp, has_header=False)
        splits_dp: IterDataPipe = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
        splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)

        images_dp = ZipArchiveReader(images_dp)

        # All annotation files list the images in the same order, so they can be
        # zipped positionally rather than joined by key, and then collated into a
        # single dict per image.
        anns_dp: IterDataPipe = Zipper(
            *[
                CelebACSVParser(dp, has_header=has_header)
                for dp, has_header in (
                    (identities_dp, False),
                    (attributes_dp, True),
                    (bboxes_dp, True),
                    (landmarks_dp, True),
                )
            ]
        )
        anns_dp: IterDataPipe = Mapper(anns_dp, self._collate_anns)

        # Join the shuffled split entries with the images by file name, ...
        dp = KeyZipper(
            splits_dp,
            images_dp,
            key_fn=getitem(0),
            ref_key_fn=path_accessor("name"),
            buffer_size=INFINITE_BUFFER_SIZE,
            keep_key=True,
        )
        # ... and then with the collated annotations by image ID.
        dp = KeyZipper(dp, anns_dp, key_fn=getitem(0), buffer_size=INFINITE_BUFFER_SIZE)
        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
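A minimal usage sketch for the new dataset. This assumes the prototype entry point torchvision.prototype.datasets.load() accepts the dataset name plus config options such as split; the loading API is not part of this diff and may differ:

from torchvision.prototype import datasets

dataset = datasets.load("celeba", split="train")
for sample in dataset:
    print(sample["path"], sample["identity"], sample["bbox"])
    break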