Unverified commit 49468279, authored by Nicolas Hug, committed by GitHub

Add support for PCAM dataset (#5203)



* Add support for PCAM dataset

* mypy

* Apply suggestions from code review
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Remove classes and class_to_idx attributes

* Use _decompress
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent 5e56575e
@@ -9,6 +9,7 @@ dependencies:
   - libpng
   - jpeg
   - ca-certificates
+  - h5py
   - pip:
     - future
     - pillow >=5.3.0, !=8.3.*
...
@@ -9,6 +9,7 @@ dependencies:
   - libpng
   - jpeg
   - ca-certificates
+  - h5py
   - pip:
     - future
     - pillow >=5.3.0, !=8.3.*
...
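Both environment files gain the same optional dependency. For a local setup, a quick sanity check that the import resolves (assuming ``pip install h5py`` or ``conda install h5py`` has been run):

```python
# Sanity check that the new optional dependency is importable.
import h5py

print(h5py.__version__)
```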
@@ -66,6 +66,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
     MNIST
     Omniglot
     OxfordIIITPet
+    PCAM
     PhotoTour
     Places365
     QMNIST
...
@@ -61,6 +61,7 @@ class LazyImporter:
         "requests",
         "scipy.io",
         "scipy.sparse",
+        "h5py",
     )

     def __init__(self):
...
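For context, here is a minimal sketch of the lazy-import pattern this list feeds, as used by the tests below via ``datasets_utils.lazy_importer.h5py``. This is illustrative only, not torchvision's actual ``LazyImporter``: the point is that the module is imported on first attribute access, so merely listing ``h5py`` does not require it to be installed.

```python
import importlib


class LazyImporter:
    """Illustrative sketch: resolve listed modules on first attribute access."""

    MODULES = ("requests", "h5py")  # dotted names like "scipy.io" need extra handling

    def __getattr__(self, name):
        if name not in self.MODULES:
            raise AttributeError(name)
        module = importlib.import_module(name)
        setattr(self, name, module)  # cache: later accesses skip __getattr__
        return module


lazy_importer = LazyImporter()
file_cls = lazy_importer.h5py.File  # h5py is imported here, on first use
```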
@@ -2577,5 +2577,28 @@ class Flowers102TestCase(datasets_utils.ImageDatasetTestCase):
         return num_images_per_split[config["split"]]


+class PCAMTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.PCAM
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test"))
+    REQUIRED_PACKAGES = ("h5py",)
+
+    def inject_fake_data(self, tmpdir: str, config):
+        base_folder = pathlib.Path(tmpdir) / "pcam"
+        base_folder.mkdir()
+
+        num_images = {"train": 2, "test": 3, "val": 4}[config["split"]]
+
+        images_file = datasets.PCAM._FILES[config["split"]]["images"][0]
+        with datasets_utils.lazy_importer.h5py.File(str(base_folder / images_file), "w") as f:
+            f["x"] = np.random.randint(0, 256, size=(num_images, 10, 10, 3), dtype=np.uint8)
+
+        targets_file = datasets.PCAM._FILES[config["split"]]["targets"][0]
+        with datasets_utils.lazy_importer.h5py.File(str(base_folder / targets_file), "w") as f:
+            f["y"] = np.random.randint(0, 2, size=(num_images, 1, 1, 1), dtype=np.uint8)
+
+        return num_images
+
+
 if __name__ == "__main__":
     unittest.main()
@@ -25,6 +25,7 @@ from .lsun import LSUN, LSUNClass
 from .mnist import MNIST, EMNIST, FashionMNIST, KMNIST, QMNIST
 from .omniglot import Omniglot
 from .oxford_iiit_pet import OxfordIIITPet
+from .pcam import PCAM
 from .phototour import PhotoTour
 from .places365 import Places365
 from .sbd import SBDataset
...
@@ -27,8 +27,8 @@ class OxfordIIITPet(VisionDataset):
         transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
             version. E.g, ``transforms.RandomCrop``.
         target_transform (callable, optional): A function/transform that takes in the target and transforms it.
-        download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/dtd``. If
-            dataset is already downloaded, it is not downloaded again.
+        download (bool, optional): If True, downloads the dataset from the internet and puts it into
+            ``root/oxford-iiit-pet``. If dataset is already downloaded, it is not downloaded again.
     """

     _RESOURCES = (
...
import pathlib
from typing import Any, Callable, Optional, Tuple

from PIL import Image

from .utils import download_file_from_google_drive, _decompress, verify_str_arg
from .vision import VisionDataset


class PCAM(VisionDataset):
    """`PCAM Dataset <https://github.com/basveeling/pcam>`_.

    The PatchCamelyon dataset is a binary classification dataset with 327,680
    color images (96px x 96px), extracted from histopathologic scans of lymph node
    sections. Each image is annotated with a binary label indicating the presence of
    metastatic tissue.

    This dataset requires the ``h5py`` package, which you can install with ``pip install h5py``.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"test"`` or ``"val"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/pcam``. If
            the dataset is already downloaded, it is not downloaded again.
    """

    _FILES = {
        "train": {
            "images": (
                "camelyonpatch_level_2_split_train_x.h5",  # Data file name
                "1Ka0XfEMiwgCYPdTI-vv6eUElOBnKFKQ2",  # Google Drive ID
                "1571f514728f59376b705fc836ff4b63",  # md5 hash
            ),
            "targets": (
                "camelyonpatch_level_2_split_train_y.h5",
                "1269yhu3pZDP8UYFQs-NYs3FPwuK-nGSG",
                "35c2d7259d906cfc8143347bb8e05be7",
            ),
        },
        "test": {
            "images": (
                "camelyonpatch_level_2_split_test_x.h5",
                "1qV65ZqZvWzuIVthK8eVDhIwrbnsJdbg_",
                "d5b63470df7cfa627aeec8b9dc0c066e",
            ),
            "targets": (
                "camelyonpatch_level_2_split_test_y.h5",
                "17BHrSrwWKjYsOgTMmoqrIjDy6Fa2o_gP",
                "2b85f58b927af9964a4c15b8f7e8f179",
            ),
        },
        "val": {
            "images": (
                "camelyonpatch_level_2_split_valid_x.h5",
                "1hgshYGWK8V-eGRy8LToWJJgDU_rXWVJ3",
                "d8c2d60d490dbd479f8199bdfa0cf6ec",
            ),
            "targets": (
                "camelyonpatch_level_2_split_valid_y.h5",
                "1bH8ZRbhSVAhScTS0p9-ZzGnX91cHT3uO",
                "60a7035772fbdb7f34eb86d4420cf66a",
            ),
        },
    }

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = True,
    ):
        try:
            import h5py  # type: ignore[import]

            self.h5py = h5py
        except ImportError:
            raise RuntimeError(
                "h5py is not found. This dataset requires h5py: please run `pip install h5py`."
            )

        self._split = verify_str_arg(split, "split", ("train", "test", "val"))

        super().__init__(root, transform=transform, target_transform=target_transform)
        self._base_folder = pathlib.Path(self.root) / "pcam"

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it.")

    def __len__(self) -> int:
        images_file = self._FILES[self._split]["images"][0]
        with self.h5py.File(self._base_folder / images_file) as images_data:
            return images_data["x"].shape[0]

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        images_file = self._FILES[self._split]["images"][0]
        with self.h5py.File(self._base_folder / images_file) as images_data:
            image = Image.fromarray(images_data["x"][idx]).convert("RGB")

        targets_file = self._FILES[self._split]["targets"][0]
        with self.h5py.File(self._base_folder / targets_file) as targets_data:
            target = int(targets_data["y"][idx, 0, 0, 0])  # shape is [num_images, 1, 1, 1]

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            target = self.target_transform(target)

        return image, target

    def _check_exists(self) -> bool:
        images_file = self._FILES[self._split]["images"][0]
        targets_file = self._FILES[self._split]["targets"][0]
        return all(self._base_folder.joinpath(h5_file).exists() for h5_file in (images_file, targets_file))

    def _download(self) -> None:
        if self._check_exists():
            return

        for file_name, file_id, md5 in self._FILES[self._split].values():
            archive_name = file_name + ".gz"
            download_file_from_google_drive(file_id, str(self._base_folder), filename=archive_name, md5=md5)
            _decompress(str(self._base_folder / archive_name))
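To round off, a short usage sketch of the new class. This is hedged: it assumes a torchvision build containing this patch plus ``h5py`` installed. ``download=True`` fetches the gzipped HDF5 archives from Google Drive into ``root/pcam`` and decompresses them, per ``_download`` above.

```python
from torchvision import datasets, transforms

pcam = datasets.PCAM(
    root="data",                      # files land in data/pcam
    split="val",
    transform=transforms.ToTensor(),  # PIL image -> 3x96x96 float tensor
    download=True,
)

image, target = pcam[0]
print(len(pcam), image.shape, target)  # target is 0 or 1 (metastatic tissue)
```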