Unverified commit f948d797, authored by Philip Meier and committed by GitHub

add CLEVR dataset (#5130)



* add prototype dataset

* add old-style dataset

* appease mypy

* simplify prototype scenes

* Update torchvision/datasets/clevr.py
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent 7120024d
@@ -2325,5 +2325,37 @@ class GTSRBTestCase(datasets_utils.ImageDatasetTestCase):
        return total_number_of_examples
class CLEVRClassificationTestCase(datasets_utils.ImageDatasetTestCase):
    DATASET_CLASS = datasets.CLEVRClassification
    FEATURE_TYPES = (PIL.Image.Image, (int, type(None)))

    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test"))

    def inject_fake_data(self, tmpdir, config):
        data_folder = pathlib.Path(tmpdir) / "clevr" / "CLEVR_v1.0"

        images_folder = data_folder / "images"
        image_files = datasets_utils.create_image_folder(
            images_folder, config["split"], lambda idx: f"CLEVR_{config['split']}_{idx:06d}.png", num_examples=5
        )

        scenes_folder = data_folder / "scenes"
        scenes_folder.mkdir()
        if config["split"] != "test":
            with open(scenes_folder / f"CLEVR_{config['split']}_scenes.json", "w") as file:
                json.dump(
                    dict(
                        info=dict(),
                        scenes=[
                            dict(image_filename=image_file.name, objects=[dict()] * int(torch.randint(10, ())))
                            for image_file in image_files
                        ],
                    ),
                    file,
                )

        return len(image_files)
if __name__ == "__main__":
    unittest.main()
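For reference, the fake data the test injects mirrors the on-disk layout that CLEVRClassification expects after extraction. A minimal standalone sketch of that layout (the root path is hypothetical and the image file is left empty for brevity):

import json
import pathlib

# Hypothetical root directory, for illustration only.
data_folder = pathlib.Path("/tmp/clevr_fake") / "clevr" / "CLEVR_v1.0"

# images/<split>/CLEVR_<split>_<index>.png
image_folder = data_folder / "images" / "train"
image_folder.mkdir(parents=True, exist_ok=True)
(image_folder / "CLEVR_train_000000.png").touch()  # a real archive contains actual PNGs

# scenes/CLEVR_<split>_scenes.json maps each image to its objects; the object
# count is what CLEVRClassification uses as the label. The "test" split ships
# no scenes file, so its labels are None.
scenes = [{"image_filename": "CLEVR_train_000000.png", "objects": [{}] * 3}]
scenes_folder = data_folder / "scenes"
scenes_folder.mkdir(exist_ok=True)
with open(scenes_folder / "CLEVR_train_scenes.json", "w") as file:
    json.dump({"info": {}, "scenes": scenes}, file)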
@@ -3,6 +3,7 @@ from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import CIFAR10, CIFAR100
from .cityscapes import Cityscapes
from .clevr import CLEVRClassification
from .coco import CocoCaptions, CocoDetection
from .dtd import DTD
from .fakedata import FakeData
@@ -85,4 +86,5 @@ __all__ = (
    "DTD",
    "FER2013",
    "GTSRB",
    "CLEVRClassification",
)
import json
import pathlib
from typing import Any, Callable, Optional, Tuple, List
from urllib.parse import urlparse

from PIL import Image

from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset


class CLEVRClassification(VisionDataset):
    """`CLEVR <https://cs.stanford.edu/people/jcjohns/clevr/>`_ classification dataset.

    The number of objects in a scene is used as the label.

    Args:
        root (string): Root directory of the dataset where directory ``root/clevr`` exists or will be saved to
            if download is set to True.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g., ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and puts it in the root
            directory. If the dataset is already downloaded, it is not downloaded again.
    """

    _URL = "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip"
    _MD5 = "b11922020e72d0cd9154779b2d3d07d2"
    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = True,
    ) -> None:
        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._base_folder = pathlib.Path(self.root) / "clevr"
        self._data_folder = self._base_folder / pathlib.Path(urlparse(self._URL).path).stem

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        self._image_files = sorted(self._data_folder.joinpath("images", self._split).glob("*"))

        self._labels: List[Optional[int]]
        if self._split != "test":
            with open(self._data_folder / "scenes" / f"CLEVR_{self._split}_scenes.json") as file:
                content = json.load(file)
            num_objects = {scene["image_filename"]: len(scene["objects"]) for scene in content["scenes"]}
            self._labels = [num_objects[image_file.name] for image_file in self._image_files]
        else:
            self._labels = [None] * len(self._image_files)

    def __len__(self) -> int:
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        image_file = self._image_files[idx]
        label = self._labels[idx]

        image = Image.open(image_file).convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def _check_exists(self) -> bool:
        return self._data_folder.exists() and self._data_folder.is_dir()

    def _download(self) -> None:
        if self._check_exists():
            return

        download_and_extract_archive(self._URL, str(self._base_folder), md5=self._MD5)

    def extra_repr(self) -> str:
        return f"split={self._split}"
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import Cifar10, Cifar100
from .clevr import CLEVR
from .coco import Coco
from .dtd import DTD
from .fer2013 import FER2013
...
import functools
import io
import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, JsonParser, UnBatcher
from torchvision.prototype.datasets.utils import (
    Dataset,
    DatasetConfig,
    DatasetInfo,
    HttpResource,
    OnlineResource,
    DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
    INFINITE_BUFFER_SIZE,
    hint_sharding,
    hint_shuffling,
    path_comparator,
    path_accessor,
    getitem,
)
from torchvision.prototype.features import Label


class CLEVR(Dataset):
    def _make_info(self) -> DatasetInfo:
        return DatasetInfo(
            "clevr",
            type=DatasetType.IMAGE,
            homepage="https://cs.stanford.edu/people/jcjohns/clevr/",
            valid_options=dict(split=("train", "val", "test")),
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        archive = HttpResource(
            "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip",
            sha256="5cd61cf1096ed20944df93c9adb31e74d189b8459a94f54ba00090e5c59936d1",
        )
        return [archive]

    def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
        path = pathlib.Path(data[0])
        if path.parents[1].name == "images":
            return 0
        elif path.parent.name == "scenes":
            return 1
        else:
            return None

    def _filter_scene_anns(self, data: Tuple[str, Any]) -> bool:
        key, _ = data
        return key == "scenes"

    def _add_empty_anns(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[str, io.IOBase], None]:
        return data, None

    def _collate_and_decode_sample(
        self,
        data: Tuple[Tuple[str, io.IOBase], Optional[Dict[str, Any]]],
        *,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        image_data, scenes_data = data
        path, buffer = image_data

        return dict(
            path=path,
            image=decoder(buffer) if decoder else buffer,
            label=Label(len(scenes_data["objects"])) if scenes_data else None,
        )

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        archive_dp = resource_dps[0]
        images_dp, scenes_dp = Demultiplexer(
            archive_dp,
            2,
            self._classify_archive,
            drop_none=True,
            buffer_size=INFINITE_BUFFER_SIZE,
        )

        images_dp = Filter(images_dp, path_comparator("parent.name", config.split))
        images_dp = hint_sharding(images_dp)
        images_dp = hint_shuffling(images_dp)

        if config.split != "test":
            scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{config.split}_scenes.json"))
            scenes_dp = JsonParser(scenes_dp)
            scenes_dp = Mapper(scenes_dp, getitem(1, "scenes"))
            scenes_dp = UnBatcher(scenes_dp)

            dp = IterKeyZipper(
                images_dp,
                scenes_dp,
                key_fn=path_accessor("name"),
                ref_key_fn=getitem("image_filename"),
                buffer_size=INFINITE_BUFFER_SIZE,
            )
        else:
            dp = Mapper(images_dp, self._add_empty_anns)

        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
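As a rough sketch of how the prototype dataset might be consumed: assuming the experimental torchvision.prototype.datasets.load entry point available on this development branch (its exact signature is subject to change), iterating the datapipe yields the dictionaries built in _collate_and_decode_sample:

from torchvision.prototype import datasets

# Assumed prototype entry point (experimental): load() is expected to return
# an IterDataPipe yielding dicts with the keys "path", "image", and "label"
# (None for the "test" split).
datapipe = datasets.load("clevr", split="train")

for sample in datapipe:
    print(sample["path"], sample["label"])
    break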
@@ -108,7 +108,7 @@ class Enumerator(IterDataPipe[Tuple[int, D]]):
        yield from enumerate(self.datapipe, self.start)


-def _getitem_closure(obj: Any, *, items: Tuple[Any, ...]) -> Any:
+def _getitem_closure(obj: Any, *, items: Sequence[Any]) -> Any:
    for item in items:
        obj = obj[item]
    return obj
@@ -118,8 +118,14 @@ def getitem(*items: Any) -> Callable[[Any], Any]:
    return functools.partial(_getitem_closure, items=items)


+def _getattr_closure(obj: Any, *, attrs: Sequence[str]) -> Any:
+    for attr in attrs:
+        obj = getattr(obj, attr)
+    return obj


def _path_attribute_accessor(path: pathlib.Path, *, name: str) -> D:
-    return cast(D, getattr(path, name))
+    return cast(D, _getattr_closure(path, attrs=name.split(".")))


def _path_accessor_closure(data: Tuple[str, Any], *, getter: Callable[[pathlib.Path], D]) -> D:
...
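The new _getattr_closure lets path_accessor and path_comparator follow dotted attribute chains, which the CLEVR datapipe relies on in path_comparator("parent.name", config.split). A standalone sketch of the same idea, independent of the torchvision helpers:

import pathlib
from typing import Any, Sequence


# Re-implementation of the helper above, for illustration only; the real one
# lives in torchvision.prototype.datasets.utils._internal.
def _getattr_closure(obj: Any, *, attrs: Sequence[str]) -> Any:
    for attr in attrs:
        obj = getattr(obj, attr)
    return obj


path = pathlib.Path("CLEVR_v1.0/images/train/CLEVR_train_000000.png")

# Previously only a single attribute name (e.g. "name") could be resolved;
# splitting on "." now supports chained lookups such as "parent.name".
print(_getattr_closure(path, attrs="name".split(".")))         # CLEVR_train_000000.png
print(_getattr_closure(path, attrs="parent.name".split(".")))  # train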