Unverified Commit 693a4632 authored by Nicolas Hug, committed by GitHub

Add PCAM dataset to prototype area (#5286)



* Add PCAM dataset to prototype area

* use BytesIO instead of writing file to disk

* Apply suggestions from code review
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* remove noqa

* typos

* Use _Resource namedtuple

* Add h5py to unittest_prototype job

* use .item() on target

* Forgot to call regenerate.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent ebbead24
.circleci/config.yml
@@ -351,7 +351,7 @@ jobs:
       - install_torchvision
       - install_prototype_dependencies
       - pip_install:
-          args: scipy pycocotools
+          args: scipy pycocotools h5py
           descr: Install optional dependencies
       - run:
           name: Enable prototype tests
.circleci/config.yml.in
@@ -351,7 +351,7 @@ jobs:
       - install_torchvision
       - install_prototype_dependencies
       - pip_install:
-          args: scipy pycocotools
+          args: scipy pycocotools h5py
           descr: Install optional dependencies
       - run:
           name: Enable prototype tests
test/builtin_dataset_mocks.py
@@ -2,6 +2,7 @@ import collections.abc
 import csv
 import functools
 import gzip
+import io
 import itertools
 import json
 import lzma
@@ -1312,3 +1313,30 @@ def svhn(info, root, config):
         },
     )
     return num_samples
+
+
+@register_mock
+def pcam(info, root, config):
+    import h5py
+
+    num_images = {"train": 2, "test": 3, "val": 4}[config.split]
+
+    split = "valid" if config.split == "val" else config.split
+
+    images_io = io.BytesIO()
+    with h5py.File(images_io, "w") as f:
+        f["x"] = np.random.randint(0, 256, size=(num_images, 10, 10, 3), dtype=np.uint8)
+
+    targets_io = io.BytesIO()
+    with h5py.File(targets_io, "w") as f:
+        f["y"] = np.random.randint(0, 2, size=(num_images, 1, 1, 1), dtype=np.uint8)
+
+    # Create .gz compressed files
+    images_file = root / f"camelyonpatch_level_2_split_{split}_x.h5.gz"
+    targets_file = root / f"camelyonpatch_level_2_split_{split}_y.h5.gz"
+    for compressed_file_name, uncompressed_file_io in ((images_file, images_io), (targets_file, targets_io)):
+        compressed_data = gzip.compress(uncompressed_file_io.getbuffer())
+        with open(compressed_file_name, "wb") as compressed_file:
+            compressed_file.write(compressed_data)
+
+    return num_images
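
For reference, a minimal round-trip sketch of the fake archives the mock writes, assuming only gzip, h5py, and numpy; the file name is the mock's train-split images file, read back from the mock's root directory:

import gzip
import io

import h5py

# Decompress one of the fake .h5.gz archives produced by the pcam mock above
# and check that the image layout matches what the mock wrote.
with open("camelyonpatch_level_2_split_train_x.h5.gz", "rb") as f:
    uncompressed = gzip.decompress(f.read())

with h5py.File(io.BytesIO(uncompressed)) as data:
    images = data["x"][()]  # materialize the HDF5 dataset as a numpy array

assert images.shape == (2, 10, 10, 3)  # the mock uses 2 images for "train"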
torchvision/prototype/datasets/_builtin/__init__.py
@@ -10,6 +10,7 @@ from .gtsrb import GTSRB
 from .imagenet import ImageNet
 from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
 from .oxford_iiit_pet import OxfordIITPet
+from .pcam import PCAM
 from .sbd import SBD
 from .semeion import SEMEION
 from .svhn import SVHN
torchvision/prototype/datasets/_builtin/pcam.py (new file)

import io
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator

import torch
from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper
from torchvision.prototype import features
from torchvision.prototype.datasets.utils import (
    Dataset,
    DatasetConfig,
    DatasetInfo,
    OnlineResource,
    DatasetType,
    GDriveResource,
)
from torchvision.prototype.datasets.utils._internal import (
    hint_sharding,
    hint_shuffling,
)
from torchvision.prototype.features import Label


class PCAMH5Reader(IterDataPipe[Tuple[str, io.IOBase]]):
    def __init__(
        self,
        datapipe: IterDataPipe[Tuple[str, io.IOBase]],
        key: Optional[str] = None,  # Note: this key thing might be very specific to the PCAM dataset
    ) -> None:
        self.datapipe = datapipe
        self.key = key

    def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]:
        import h5py

        for _, handle in self.datapipe:
            with h5py.File(handle) as data:
                if self.key is not None:
                    data = data[self.key]
                yield from data
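
The reader's yield from data leans on h5py's iteration semantics: looping over an h5py dataset yields one numpy array per entry along the leading axis, so with key="x" the reader emits one (H, W, 3) image per sample. A self-contained sketch of that behaviour, with shapes chosen to mirror the mock above rather than the real data:

import io

import h5py
import numpy as np

# Write a tiny in-memory HDF5 file shaped like a PCAM split.
buffer = io.BytesIO()
with h5py.File(buffer, "w") as f:
    f["x"] = np.zeros((2, 10, 10, 3), dtype=np.uint8)

# Iterating over the dataset yields one array per leading-axis entry,
# which is why PCAMH5Reader can simply `yield from data`.
with h5py.File(buffer) as f:
    for image in f["x"]:
        print(image.shape)  # (10, 10, 3)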
_Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256"))


class PCAM(Dataset):
    def _make_info(self) -> DatasetInfo:
        return DatasetInfo(
            "pcam",
            type=DatasetType.RAW,
            homepage="https://github.com/basveeling/pcam",
            categories=2,
            valid_options=dict(split=("train", "test", "val")),
            dependencies=["h5py"],
        )

    _RESOURCES = {
        "train": (
            _Resource(  # Images
                file_name="camelyonpatch_level_2_split_train_x.h5.gz",
                gdrive_id="1Ka0XfEMiwgCYPdTI-vv6eUElOBnKFKQ2",
                sha256="d619e741468a7ab35c7e4a75e6821b7e7e6c9411705d45708f2a0efc8960656c",
            ),
            _Resource(  # Targets
                file_name="camelyonpatch_level_2_split_train_y.h5.gz",
                gdrive_id="1269yhu3pZDP8UYFQs-NYs3FPwuK-nGSG",
                sha256="b74126d2c01b20d3661f9b46765d29cf4e4fba6faba29c8e0d09d406331ab75a",
            ),
        ),
        "test": (
            _Resource(  # Images
                file_name="camelyonpatch_level_2_split_test_x.h5.gz",
                gdrive_id="1qV65ZqZvWzuIVthK8eVDhIwrbnsJdbg_",
                sha256="79174c2201ad521602a5888be8f36ee10875f37403dd3f2086caf2182ef87245",
            ),
            _Resource(  # Targets
                file_name="camelyonpatch_level_2_split_test_y.h5.gz",
                gdrive_id="17BHrSrwWKjYsOgTMmoqrIjDy6Fa2o_gP",
                sha256="0a522005fccc8bbd04c5a117bfaf81d8da2676f03a29d7499f71d0a0bd6068ef",
            ),
        ),
        "val": (
            _Resource(  # Images
                file_name="camelyonpatch_level_2_split_valid_x.h5.gz",
                gdrive_id="1hgshYGWK8V-eGRy8LToWJJgDU_rXWVJ3",
                sha256="f82ee1670d027b4ec388048d9eabc2186b77c009655dae76d624c0ecb053ccb2",
            ),
            _Resource(  # Targets
                file_name="camelyonpatch_level_2_split_valid_y.h5.gz",
                gdrive_id="1bH8ZRbhSVAhScTS0p9-ZzGnX91cHT3uO",
                sha256="ce1ae30f08feb468447971cfd0472e7becd0ad96d877c64120c72571439ae48c",
            ),
        ),
    }

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        return [  # = [images resource, targets resource]
            GDriveResource(file_name=file_name, id=gdrive_id, sha256=sha256, decompress=True)
            for file_name, gdrive_id, sha256 in self._RESOURCES[config.split]
        ]

    def _collate_and_decode(self, data: Tuple[Any, Any]) -> Dict[str, Any]:
        image, target = data  # They're both numpy arrays at this point

        return {
            "image": features.Image(image),
            "label": Label(target.item()),
        }

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        images_dp, targets_dp = resource_dps

        images_dp = PCAMH5Reader(images_dp, key="x")
        targets_dp = PCAMH5Reader(targets_dp, key="y")

        dp = Zipper(images_dp, targets_dp)
        dp = hint_sharding(dp)
        dp = hint_shuffling(dp)
        return Mapper(dp, self._collate_and_decode)
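
Putting it together, a hedged usage sketch, assuming the prototype area's datasets.load entry point accepts the options declared in valid_options as keyword arguments:

from torchvision.prototype import datasets

# Load PCAM through the prototype entry point and inspect one sample.
# Each sample is the dict built by `_collate_and_decode`: "image" is a
# `features.Image`, and "label" wraps the scalar that `.item()` pulls out
# of the (1, 1, 1) uint8 target array.
dataset = datasets.load("pcam", split="train")
sample = next(iter(dataset))
print(sample["image"].shape, int(sample["label"]))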