Unverified commit 5c9c8352 authored by Philip Meier, committed by GitHub

Add DTD dataset (#5115)



* add DTD as prototype dataset

* add old style dataset

* add test for old dataset

* fix tests for windows

* add dataset to docs

* remove properties and use pathlib

* Apply suggestions from code review
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

* fold -> partition
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent df628c49
@@ -38,6 +38,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
Cityscapes
CocoCaptions
CocoDetection
DTD
EMNIST
FakeData
FashionMNIST
...
@@ -2205,5 +2205,41 @@ class Food101TestCase(datasets_utils.ImageDatasetTestCase):
        return len(sampled_classes * n_samples_per_class)
class DTDTestCase(datasets_utils.ImageDatasetTestCase):
    DATASET_CLASS = datasets.DTD
    FEATURE_TYPES = (PIL.Image.Image, int)

    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
        split=("train", "test", "val"),
        # There is no need to test the whole matrix here, since each partition is treated exactly the same
        partition=(1, 5, 10),
    )

    def inject_fake_data(self, tmpdir: str, config):
        data_folder = pathlib.Path(tmpdir) / "dtd" / "dtd"

        num_images_per_class = 3
        image_folder = data_folder / "images"
        image_files = []
        for cls in ("banded", "marbled", "zigzagged"):
            image_files.extend(
                datasets_utils.create_image_folder(
                    image_folder,
                    cls,
                    file_name_fn=lambda idx: f"{cls}_{idx:04d}.jpg",
                    num_examples=num_images_per_class,
                )
            )

        meta_folder = data_folder / "labels"
        meta_folder.mkdir()
        image_ids = [str(path.relative_to(path.parents[1])).replace(os.sep, "/") for path in image_files]
        image_ids_in_config = random.choices(image_ids, k=len(image_files) // 2)
        with open(meta_folder / f"{config['split']}{config['partition']}.txt", "w") as file:
            file.write("\n".join(image_ids_in_config) + "\n")

        return len(image_ids_in_config)
if __name__ == "__main__":
    unittest.main()
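For context, the fake data above mirrors the on-disk layout of the real DTD archive that the new dataset class below relies on: images live under dtd/dtd/images/<class>/<name>.jpg, and each split/partition combination has a text file under dtd/dtd/labels/ whose lines are image ids of the form <class>/<name>.jpg. A minimal sketch of resolving one such id to an image path, assuming an already extracted archive (the root path and the id are illustrative, not taken from the real data):

import pathlib

root = pathlib.Path("./data/dtd/dtd")      # hypothetical extraction root
line = "banded/banded_0002.jpg"            # illustrative line from labels/train1.txt

cls, name = line.strip().split("/")        # class folder and file name
image_path = root / "images" / cls / name  # -> ./data/dtd/dtd/images/banded/banded_0002.jpg
print(cls, image_path)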
@@ -4,6 +4,7 @@ from .celeba import CelebA
from .cifar import CIFAR10, CIFAR100
from .cityscapes import Cityscapes
from .coco import CocoCaptions, CocoDetection
from .dtd import DTD
from .fakedata import FakeData
from .flickr import Flickr8k, Flickr30k
from .folder import ImageFolder, DatasetFolder
@@ -79,4 +80,5 @@ __all__ = (
    "FlyingThings3D",
    "HD1K",
    "Food101",
    "DTD",
)
import os
import pathlib
from typing import Optional, Callable

import PIL.Image

from .utils import verify_str_arg, download_and_extract_archive
from .vision import VisionDataset


class DTD(VisionDataset):
    """`Describable Textures Dataset (DTD) <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`_.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
        partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``.

            .. note::

                The partition only changes which split each image belongs to. Thus, regardless of the selected
                partition, combining all splits will result in all images.

        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g., ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
    """

    _URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
    _MD5 = "fff73e5086ae6bdbea199a49dfb8a4c1"

    def __init__(
        self,
        root: str,
        split: str = "train",
        partition: int = 1,
        download: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
        if not isinstance(partition, int) or not (1 <= partition <= 10):
            raise ValueError(
                f"Parameter 'partition' should be an integer with `1 <= partition <= 10`, "
                f"but got {partition} instead"
            )
        self._partition = partition

        super().__init__(root, transform=transform, target_transform=target_transform)
        self._base_folder = pathlib.Path(self.root) / type(self).__name__.lower()
        self._data_folder = self._base_folder / "dtd"
        self._meta_folder = self._data_folder / "labels"
        self._images_folder = self._data_folder / "images"

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        self._image_files = []
        classes = []
        with open(self._meta_folder / f"{self._split}{self._partition}.txt") as file:
            for line in file:
                cls, name = line.strip().split("/")
                self._image_files.append(self._images_folder.joinpath(cls, name))
                classes.append(cls)

        self.classes = sorted(set(classes))
        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
        self._labels = [self.class_to_idx[cls] for cls in classes]

    def __len__(self) -> int:
        return len(self._image_files)

    def __getitem__(self, idx):
        image_file, label = self._image_files[idx], self._labels[idx]
        image = PIL.Image.open(image_file).convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        return f"split={self._split}, partition={self._partition}"

    def _check_exists(self) -> bool:
        return os.path.exists(self._data_folder) and os.path.isdir(self._data_folder)

    def _download(self) -> None:
        if self._check_exists():
            return
        download_and_extract_archive(self._URL, download_root=str(self._base_folder), md5=self._MD5)
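A minimal usage sketch of the new stable dataset, assuming torchvision is installed with this commit; the root path and the transform are illustrative, not part of the change:

import torchvision.transforms as T
from torchvision.datasets import DTD

# Downloads and extracts dtd-r1.0.1.tar.gz into ./data/dtd on first use.
dataset = DTD(root="./data", split="val", partition=1, download=True, transform=T.ToTensor())

image, label = dataset[0]  # tensor image and integer label after ToTensor
print(len(dataset), dataset.classes[label])

As the docstring's note points out, the partition only reshuffles which images land in which split, so the union of the train, val, and test splits is the same for every partition.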
@@ -2,6 +2,7 @@ from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import Cifar10, Cifar100
from .coco import Coco
from .dtd import DTD
from .imagenet import ImageNet
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
from .sbd import SBD
...
banded
blotchy
braided
bubbly
bumpy
chequered
cobwebbed
cracked
crosshatched
crystalline
dotted
fibrous
flecked
freckled
frilly
gauzy
grid
grooved
honeycombed
interlaced
knitted
lacelike
lined
marbled
matted
meshed
paisley
perforated
pitted
pleated
polka-dotted
porous
potholed
scaly
smeared
spiralled
sprinkled
stained
stratified
striped
studded
swirly
veined
waffled
woven
wrinkled
zigzagged
import io
import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from torchdata.datapipes.iter import (
    IterDataPipe,
    Mapper,
    Shuffler,
    Filter,
    IterKeyZipper,
    Demultiplexer,
    LineReader,
    CSVParser,
)
from torchvision.prototype.datasets.utils import (
    Dataset,
    DatasetConfig,
    DatasetInfo,
    HttpResource,
    OnlineResource,
    DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
    INFINITE_BUFFER_SIZE,
    hint_sharding,
    path_comparator,
    getitem,
)
from torchvision.prototype.features import Label
class DTD(Dataset):
    def _make_info(self) -> DatasetInfo:
        return DatasetInfo(
            "dtd",
            type=DatasetType.IMAGE,
            homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/",
            valid_options=dict(
                split=("train", "test", "val"),
                fold=tuple(str(fold) for fold in range(1, 11)),
            ),
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        archive = HttpResource(
            "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz",
            sha256="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205",
            decompress=True,
        )
        return [archive]

    def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
        path = pathlib.Path(data[0])
        if path.parent.name == "labels":
            if path.name == "labels_joint_anno.txt":
                return 1
            return 0
        elif path.parents[1].name == "images":
            return 2
        else:
            return None

    def _image_key_fn(self, data: Tuple[str, Any]) -> str:
        path = pathlib.Path(data[0])
        return str(path.relative_to(path.parents[1]))

    def _collate_and_decode_sample(
        self,
        data: Tuple[Tuple[str, List[str]], Tuple[str, io.IOBase]],
        *,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        (_, joint_categories_data), image_data = data
        _, *joint_categories = joint_categories_data
        path, buffer = image_data

        category = pathlib.Path(path).parent.name

        return dict(
            joint_categories={category for category in joint_categories if category},
            label=Label(self.info.categories.index(category), category=category),
            path=path,
            image=decoder(buffer) if decoder else buffer,
        )

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        archive_dp = resource_dps[0]
        splits_dp, joint_categories_dp, images_dp = Demultiplexer(
            archive_dp, 3, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
        )

        splits_dp = Filter(splits_dp, path_comparator("name", f"{config.split}{config.fold}.txt"))
        splits_dp = LineReader(splits_dp, decode=True, return_path=False)
        splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)
        splits_dp = hint_sharding(splits_dp)

        joint_categories_dp = CSVParser(joint_categories_dp, delimiter=" ")

        dp = IterKeyZipper(
            splits_dp,
            joint_categories_dp,
            key_fn=getitem(),
            ref_key_fn=getitem(0),
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        dp = IterKeyZipper(
            dp,
            images_dp,
            key_fn=getitem(0),
            ref_key_fn=self._image_key_fn,
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

    def _filter_images(self, data: Tuple[str, Any]) -> bool:
        return self._classify_archive(data) == 2

    def _generate_categories(self, root: pathlib.Path) -> List[str]:
        dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
        dp = Filter(dp, self._filter_images)
        return sorted({pathlib.Path(path).parent.name for path, _ in dp})
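For the prototype dataset, a hedged usage sketch assuming the prototype's datasets.load entry point, which forwards the options declared in _make_info; the exact entry point and option handling may differ in other prototype revisions:

from torchvision.prototype import datasets

# "fold" is a string option here ("1" through "10"), unlike the stable dataset's integer partition.
dtd = datasets.load("dtd", split="train", fold="1")

for sample in dtd:
    # Each sample is the dict built in _collate_and_decode_sample above.
    print(sample["path"], sample["label"], sorted(sample["joint_categories"]))
    break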