"vscode:/vscode.git/clone" did not exist on "f2f085bf099c4c31bc6f09c21844b2d57dabcb87"
Unverified commit 693a4632 authored by Nicolas Hug, committed by GitHub
Browse files

Add PCAM dataset to prototype area (#5286)



* Add PCAM dataset to prototype area

* use BytesIO instead of writing file to disk

* Apply suggestions from code review
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* remove noqa

* typos

* Use _Resource namedtuple

* Add h5py to unittest_prototype job

* use .item() on target

* Forgot to call regenerate.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent ebbead24
......@@ -351,7 +351,7 @@ jobs:
- install_torchvision
- install_prototype_dependencies
- pip_install:
args: scipy pycocotools
args: scipy pycocotools h5py
descr: Install optional dependencies
- run:
name: Enable prototype tests
......
......@@ -351,7 +351,7 @@ jobs:
- install_torchvision
- install_prototype_dependencies
- pip_install:
args: scipy pycocotools
args: scipy pycocotools h5py
descr: Install optional dependencies
- run:
name: Enable prototype tests
......
......@@ -2,6 +2,7 @@ import collections.abc
import csv
import functools
import gzip
import io
import itertools
import json
import lzma
......@@ -1312,3 +1313,30 @@ def svhn(info, root, config):
},
)
return num_samples
@register_mock
def pcam(info, root, config):
    """Generate a fake on-disk PCAM split: two gzip-compressed HDF5 files
    (images under key ``"x"``, targets under key ``"y"``) written into *root*.

    Returns the number of samples in the generated split.
    """
    import h5py

    num_images = {"train": 2, "test": 3, "val": 4}[config.split]
    # The official archives name the validation split "valid".
    split = "valid" if config.split == "val" else config.split

    def _h5_buffer(key, array):
        # Serialize a single-dataset HDF5 file into an in-memory buffer so we
        # never write an uncompressed .h5 to disk.
        buffer = io.BytesIO()
        with h5py.File(buffer, "w") as h5_file:
            h5_file[key] = array
        return buffer

    images_io = _h5_buffer("x", np.random.randint(0, 256, size=(num_images, 10, 10, 3), dtype=np.uint8))
    targets_io = _h5_buffer("y", np.random.randint(0, 2, size=(num_images, 1, 1, 1), dtype=np.uint8))

    # Create .gz compressed files
    images_file = root / f"camelyonpatch_level_2_split_{split}_x.h5.gz"
    targets_file = root / f"camelyonpatch_level_2_split_{split}_y.h5.gz"
    for compressed_file_name, uncompressed_file_io in ((images_file, images_io), (targets_file, targets_io)):
        with open(compressed_file_name, "wb") as compressed_file:
            compressed_file.write(gzip.compress(uncompressed_file_io.getbuffer()))

    return num_images
......@@ -10,6 +10,7 @@ from .gtsrb import GTSRB
from .imagenet import ImageNet
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
from .oxford_iiit_pet import OxfordIITPet
from .pcam import PCAM
from .sbd import SBD
from .semeion import SEMEION
from .svhn import SVHN
......
import io
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator
import torch
from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper
from torchvision.prototype import features
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
OnlineResource,
DatasetType,
GDriveResource,
)
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
)
from torchvision.prototype.features import Label
class PCAMH5Reader(IterDataPipe[Tuple[str, io.IOBase]]):
    """Iterate over the rows of an HDF5 dataset inside each incoming file handle.

    If *key* is given, rows are read from that named dataset (e.g. ``"x"`` or
    ``"y"`` for PCAM); otherwise the HDF5 file object itself is iterated.
    """

    def __init__(
        self,
        datapipe: IterDataPipe[Tuple[str, io.IOBase]],
        key: Optional[str] = None,  # Note: this key thing might be very specific to the PCAM dataset
    ) -> None:
        self.datapipe = datapipe
        self.key = key

    def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]:
        # Deferred import: h5py is an optional dependency.
        import h5py

        for _, handle in self.datapipe:
            with h5py.File(handle) as h5_file:
                # Select the named dataset when a key was supplied, then yield
                # its rows while the file is still open.
                source = h5_file if self.key is None else h5_file[self.key]
                yield from source
_Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256"))
class PCAM(Dataset):
    # PatchCamelyon (PCAM): histopathology image patches with binary labels,
    # distributed as gzip-compressed HDF5 files hosted on Google Drive.
    # Homepage: https://github.com/basveeling/pcam

    def _make_info(self) -> DatasetInfo:
        """Static dataset metadata: name, raw data type, homepage, the two
        label categories, the valid split options, and the h5py runtime
        dependency required to read the files."""
        return DatasetInfo(
            "pcam",
            type=DatasetType.RAW,
            homepage="https://github.com/basveeling/pcam",
            categories=2,
            valid_options=dict(split=("train", "test", "val")),
            dependencies=["h5py"],
        )

    # One (images, targets) resource pair per split. Note that the on-disk
    # file names use "valid" for the "val" split.
    _RESOURCES = {
        "train": (
            _Resource(  # Images
                file_name="camelyonpatch_level_2_split_train_x.h5.gz",
                gdrive_id="1Ka0XfEMiwgCYPdTI-vv6eUElOBnKFKQ2",
                sha256="d619e741468a7ab35c7e4a75e6821b7e7e6c9411705d45708f2a0efc8960656c",
            ),
            _Resource(  # Targets
                file_name="camelyonpatch_level_2_split_train_y.h5.gz",
                gdrive_id="1269yhu3pZDP8UYFQs-NYs3FPwuK-nGSG",
                sha256="b74126d2c01b20d3661f9b46765d29cf4e4fba6faba29c8e0d09d406331ab75a",
            ),
        ),
        "test": (
            _Resource(  # Images
                file_name="camelyonpatch_level_2_split_test_x.h5.gz",
                gdrive_id="1qV65ZqZvWzuIVthK8eVDhIwrbnsJdbg_",
                sha256="79174c2201ad521602a5888be8f36ee10875f37403dd3f2086caf2182ef87245",
            ),
            _Resource(  # Targets
                file_name="camelyonpatch_level_2_split_test_y.h5.gz",
                gdrive_id="17BHrSrwWKjYsOgTMmoqrIjDy6Fa2o_gP",
                sha256="0a522005fccc8bbd04c5a117bfaf81d8da2676f03a29d7499f71d0a0bd6068ef",
            ),
        ),
        "val": (
            _Resource(  # Images
                file_name="camelyonpatch_level_2_split_valid_x.h5.gz",
                gdrive_id="1hgshYGWK8V-eGRy8LToWJJgDU_rXWVJ3",
                sha256="f82ee1670d027b4ec388048d9eabc2186b77c009655dae76d624c0ecb053ccb2",
            ),
            _Resource(  # Targets
                file_name="camelyonpatch_level_2_split_valid_y.h5.gz",
                gdrive_id="1bH8ZRbhSVAhScTS0p9-ZzGnX91cHT3uO",
                sha256="ce1ae30f08feb468447971cfd0472e7becd0ad96d877c64120c72571439ae48c",
            ),
        ),
    }

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        """Return the two GDrive resources for the configured split.

        ``decompress=True`` because the hosted files are ``.h5.gz`` and the
        HDF5 reader downstream expects the decompressed ``.h5`` content.
        """
        return [  # = [images resource, targets resource]
            GDriveResource(file_name=file_name, id=gdrive_id, sha256=sha256, decompress=True)
            for file_name, gdrive_id, sha256 in self._RESOURCES[config.split]
        ]

    def _collate_and_decode(self, data: Tuple[Any, Any]) -> Dict[str, Any]:
        # Wrap one (image array, target array) pair into the sample dict.
        # .item() extracts the scalar label from the 1x1x1 target array.
        image, target = data  # They're both numpy arrays at this point
        return {
            "image": features.Image(image),
            "label": Label(target.item()),
        }

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        """Build the sample pipeline: read image and target rows from their
        respective HDF5 files, zip them into pairs, apply the sharding and
        shuffling hints, and collate each pair into a sample dict."""
        images_dp, targets_dp = resource_dps

        # The resources() order guarantees [images, targets]; "x" and "y" are
        # the dataset keys inside the respective HDF5 files.
        images_dp = PCAMH5Reader(images_dp, key="x")
        targets_dp = PCAMH5Reader(targets_dp, key="y")

        dp = Zipper(images_dp, targets_dp)
        dp = hint_sharding(dp)
        dp = hint_shuffling(dp)
        return Mapper(dp, self._collate_and_decode)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment