Unverified Commit fc69c225 authored by Philip Meier's avatar Philip Meier Committed by GitHub

Places365 dataset (#2610)

* initial draft

* [dirty] progress

* remove inheritance from ImageFolder

* add tests

* lint

* fix type hints

* align getitem with other datasets

* remove unused import

* add docstring

* guard existing image folders from overwrite

* add missing entry in docstring

* make fixpath more legible

* add Places365 to docs
parent 3e0f5a6f
@@ -166,6 +166,13 @@ PhotoTour
  :members: __getitem__
  :special-members:

Places365
~~~~~~~~~

.. autoclass:: Places365
  :members: __getitem__
  :special-members:

QMNIST
~~~~~~
@@ -10,6 +10,8 @@ import pickle
import random
from itertools import cycle
from torchvision.io.video import write_video
import unittest.mock
import hashlib


@contextlib.contextmanager
@@ -312,3 +314,105 @@ def ucf101_root():
        for f in file_handles:
            f.close()
        yield (video_dir, annotations)


@contextlib.contextmanager
def places365_root(split="train-standard", small=False, extract_images=True):
    CATEGORIES = (("/a/airfield", 0), ("/a/apartment_building/outdoor", 8), ("/b/badlands", 30))
    FILE_LIST = [(f"{idx}.png", idx) for idx in tuple(zip(*CATEGORIES))[1]]

    def compute_md5(file):
        with open(file, "rb") as fh:
            return hashlib.md5(fh.read()).hexdigest()

    def make_txt(root, name, cls_or_image_seq):
        file = os.path.join(root, name)
        with open(file, "w") as fh:
            for cls_or_image, idx in cls_or_image_seq:
                fh.write(f"{cls_or_image} {idx}\n")
        return name, compute_md5(file)

    def make_categories_txt(root, name):
        return make_txt(root, name, CATEGORIES)

    def make_file_list_txt(root, name):
        return make_txt(root, name, FILE_LIST)

    def make_image(root, name, size):
        PIL.Image.fromarray(np.zeros((*size, 3), dtype=np.uint8)).save(os.path.join(root, name))

    def make_tar(root, name, files, remove_files=True):
        archive = os.path.join(root, name)
        files = [os.path.join(root, file) for file in files]

        with tarfile.open(archive, "w") as fh:
            for file in files:
                fh.add(file, os.path.basename(file))

        if remove_files:
            for file in files:
                os.remove(file)

        return name, compute_md5(archive)

    def mock_target(attr, partial="torchvision.datasets.places365.Places365"):
        return f"{partial}.{attr}"

    def mock_class_attribute(stack, attr, new):
        mock = unittest.mock.patch(mock_target(attr), new_callable=unittest.mock.PropertyMock, return_value=new)
        stack.enter_context(mock)
        return mock

    def split_to_variant(split):
        return "challenge" if split == "train-challenge" else "standard"

    def make_devkit_archive(stack, root, split):
        variant = split_to_variant(split)
        archive = f"filelist_places365-{variant}.tar"
        files = []

        meta = make_categories_txt(root, "categories_places365.txt")
        mock_class_attribute(stack, "_CATEGORIES_META", meta)
        files.append(meta[0])

        meta = {
            split: make_file_list_txt(root, f"places365_{split.replace('-', '_')}.txt")
            for split in (f"train-{variant}", "val", "test")
        }
        mock_class_attribute(stack, "_FILE_LIST_META", meta)
        files.extend([item[0] for item in meta.values()])

        meta = {variant: make_tar(root, archive, files)}
        mock_class_attribute(stack, "_DEVKIT_META", meta)

    def make_images_archive(stack, root, split, small):
        if split.startswith("train"):
            images_dir = f"train_{'256' if small else 'large'}_places365{split_to_variant(split)}"
        else:
            images_dir = f"{split}_{'256' if small else 'large'}"
        archive = f"{images_dir}.tar"

        size = (256, 256) if small else (512, random.randint(512, 1024))
        imgs = [item[0] for item in FILE_LIST]
        for img in imgs:
            make_image(root, img, size)

        meta = {(split, small): make_tar(root, archive, imgs)}
        mock_class_attribute(stack, "_IMAGES_META", meta)

        return images_dir

    with contextlib.ExitStack() as stack:
        with get_tmp_dir() as root:
            make_devkit_archive(stack, root, split)

            class_to_idx = dict(CATEGORIES)
            classes = list(class_to_idx.keys())

            data = {"class_to_idx": class_to_idx, "classes": classes}

            if extract_images:
                images_dir = make_images_archive(stack, root, split, small)
                data["imgs"] = [(os.path.join(root, images_dir, file), idx) for file, idx in FILE_LIST]
            else:
                stack.enter_context(unittest.mock.patch(mock_target("download_images")))

            yield root, data
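
The fixture above generates tiny fake archives in a temporary root and then patches the dataset's class-level checksum metadata, so integrity checks pass against the fake files instead of the real Places365 data. A minimal standalone sketch of that patching pattern (illustration only, not part of the patch; the md5 string is a placeholder):

import contextlib
import unittest.mock

with contextlib.ExitStack() as stack:
    # Swap the (file, md5) class attribute for fake metadata. PropertyMock
    # makes attribute access on the class or its instances return the tuple.
    stack.enter_context(unittest.mock.patch(
        "torchvision.datasets.places365.Places365._CATEGORIES_META",
        new_callable=unittest.mock.PropertyMock,
        return_value=("categories_places365.txt", "<md5-of-fake-file>"),
    ))
    # Any Places365 instance created here sees the patched metadata.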
@@ -9,8 +9,10 @@ from torch._utils_internal import get_file_path_2
import torchvision
from common_utils import get_tmp_dir
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
    cityscapes_root, svhn_root, voc_root, ucf101_root
    cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root
import xml.etree.ElementTree as ET
from urllib.request import Request, urlopen
import itertools

try:
@@ -280,6 +282,104 @@ class Tester(unittest.TestCase):
            self.assertEqual(audio.numel(), 0)
            self.assertEqual(label, 1)

    def test_places365(self):
        for split, small in itertools.product(("train-standard", "train-challenge", "val", "test"), (False, True)):
            with places365_root(split=split, small=small) as places365:
                root, data = places365

                dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)
                self.generic_classification_dataset_test(dataset, num_images=len(data["imgs"]))

    def test_places365_transforms(self):
        expected_image = "image"
        expected_target = "target"

        def transform(image):
            return expected_image

        def target_transform(target):
            return expected_target

        with places365_root() as places365:
            root, data = places365

            dataset = torchvision.datasets.Places365(
                root, transform=transform, target_transform=target_transform, download=True
            )
            actual_image, actual_target = dataset[0]

            self.assertEqual(actual_image, expected_image)
            self.assertEqual(actual_target, expected_target)

    @mock.patch("torchvision.datasets.utils.download_url")
    def test_places365_downloadable(self, download_url):
        for split, small in itertools.product(("train-standard", "train-challenge", "val", "test"), (False, True)):
            with places365_root(split=split, small=small) as places365:
                root, data = places365

                torchvision.datasets.Places365(root, split=split, small=small, download=True)

        urls = {call_args[0][0] for call_args in download_url.call_args_list}
        for url in urls:
            with self.subTest(url=url):
                response = urlopen(Request(url, method="HEAD"))
                assert response.code == 200, f"Server returned status code {response.code} for {url}."

    def test_places365_devkit_download(self):
        for split in ("train-standard", "train-challenge", "val", "test"):
            with self.subTest(split=split):
                with places365_root(split=split) as places365:
                    root, data = places365

                    dataset = torchvision.datasets.Places365(root, split=split, download=True)

                    with self.subTest("classes"):
                        self.assertSequenceEqual(dataset.classes, data["classes"])

                    with self.subTest("class_to_idx"):
                        self.assertDictEqual(dataset.class_to_idx, data["class_to_idx"])

                    with self.subTest("imgs"):
                        self.assertSequenceEqual(dataset.imgs, data["imgs"])

    def test_places365_devkit_no_download(self):
        for split in ("train-standard", "train-challenge", "val", "test"):
            with self.subTest(split=split):
                with places365_root(split=split, extract_images=False) as places365:
                    root, data = places365

                    with self.assertRaises(RuntimeError):
                        torchvision.datasets.Places365(root, split=split, download=False)

    def test_places365_images_download(self):
        for split, small in itertools.product(("train-standard", "train-challenge", "val", "test"), (False, True)):
            with self.subTest(split=split, small=small):
                with places365_root(split=split, small=small) as places365:
                    root, data = places365

                    dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)

                    assert all(os.path.exists(item[0]) for item in dataset.imgs)

    def test_places365_images_download_preexisting(self):
        split = "train-standard"
        small = False
        images_dir = "train_large_places365standard"

        with places365_root(split=split, small=small) as places365:
            root, data = places365
            os.mkdir(os.path.join(root, images_dir))

            with self.assertRaises(RuntimeError):
                torchvision.datasets.Places365(root, split=split, small=small, download=True)

    def test_places365_repr_smoke(self):
        with places365_root(extract_images=False) as places365:
            root, data = places365

            dataset = torchvision.datasets.Places365(root, download=True)
            self.assertIsInstance(repr(dataset), str)


if __name__ == '__main__':
    unittest.main()
@@ -22,6 +22,7 @@ from .usps import USPS
from .kinetics import Kinetics400
from .hmdb51 import HMDB51
from .ucf101 import UCF101
from .places365 import Places365

__all__ = ('LSUN', 'LSUNClass',
           'ImageFolder', 'DatasetFolder', 'FakeData',
@@ -31,4 +32,4 @@ __all__ = ('LSUN', 'LSUNClass',
           'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
           'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
           'Caltech101', 'Caltech256', 'CelebA', 'SBDataset', 'VisionDataset',
           'USPS', 'Kinetics400', 'HMDB51', 'UCF101')
           'USPS', 'Kinetics400', 'HMDB51', 'UCF101', 'Places365')
import os
from os import path
from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.parse import urljoin

from .folder import default_loader
from .utils import verify_str_arg, check_integrity, download_and_extract_archive
from .vision import VisionDataset


class Places365(VisionDataset):
r"""`Places365 <http://places2.csail.mit.edu/index.html>`_ classification dataset.
Args:
root (string): Root directory of the Places365 dataset.
split (string, optional): The dataset split. Can be one of ``train-standard`` (default), ``train-challendge``,
``val``, and ``test``.
small (bool, optional): If ``True``, uses the small images, i. e. resized to 256 x 256 pixels, instead of the
high resolution ones.
download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
downloaded archives are not downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
Attributes:
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
targets (list): The class_index value for each image in the dataset
Raises:
RuntimeError: If ``download is False`` and the meta files, i. e. the devkit, are not present or corrupted.
RuntimeError: If ``download is True`` and the image archive is already extracted.
"""
    _SPLITS = ("train-standard", "train-challenge", "val", "test")
    _BASE_URL = "http://data.csail.mit.edu/places/places365/"
    # {variant: (archive, md5)}
    _DEVKIT_META = {
        "standard": ("filelist_places365-standard.tar", "35a0585fee1fa656440f3ab298f8479c"),
        "challenge": ("filelist_places365-challenge.tar", "70a8307e459c3de41690a7c76c931734"),
    }
    # (file, md5)
    _CATEGORIES_META = ("categories_places365.txt", "06c963b85866bd0649f97cb43dd16673")
    # {split: (file, md5)}
    _FILE_LIST_META = {
        "train-standard": ("places365_train_standard.txt", "30f37515461640559006b8329efbed1a"),
        "train-challenge": ("places365_train_challenge.txt", "b2931dc997b8c33c27e7329c073a6b57"),
        "val": ("places365_val.txt", "e9f2fd57bfd9d07630173f4e8708e4b1"),
        "test": ("places365_test.txt", "2fce8233fe493576d724142e45d93653"),
    }
    # {(split, small): (file, md5)}
    _IMAGES_META = {
        ("train-standard", False): ("train_large_places365standard.tar", "67e186b496a84c929568076ed01a8aa1"),
        ("train-challenge", False): ("train_large_places365challenge.tar", "605f18e68e510c82b958664ea134545f"),
        ("val", False): ("val_large.tar", "9b71c4993ad89d2d8bcbdc4aef38042f"),
        ("test", False): ("test_large.tar", "41a4b6b724b1d2cd862fb3871ed59913"),
        ("train-standard", True): ("train_256_places365standard.tar", "53ca1c756c3d1e7809517cc47c5561c5"),
        ("train-challenge", True): ("train_256_places365challenge.tar", "741915038a5e3471ec7332404dfb64ef"),
        ("val", True): ("val_256.tar", "e27b17d8d44f4af9a78502beb927f808"),
        ("test", True): ("test_256.tar", "f532f6ad7b582262a2ec8009075e186b"),
    }
    def __init__(
        self,
        root: str,
        split: str = "train-standard",
        small: bool = False,
        download: bool = False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)

        self.split = self._verify_split(split)
        self.small = small
        self.loader = loader

        self.classes, self.class_to_idx = self.load_categories(download)
        self.imgs, self.targets = self.load_file_list(download)

        if download:
            self.download_images()

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        file, target = self.imgs[index]
        image = self.loader(file)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        return len(self.imgs)

    @property
    def images_dir(self) -> str:
        file, _ = self._IMAGES_META[(self.split, self.small)]
        return path.join(self.root, path.splitext(file)[0])
    def load_categories(self, download: bool = True) -> Tuple[List[str], Dict[str, int]]:
        def process(line: str) -> Tuple[str, int]:
            cls, idx = line.split()
            return cls, int(idx)

        file, md5 = self._CATEGORIES_META
        file = path.join(self.root, file)
        if not self._check_integrity(file, md5, download):
            self.download_devkit()

        with open(file, "r") as fh:
            class_to_idx = dict(process(line) for line in fh)

        return sorted(class_to_idx.keys()), class_to_idx

    def load_file_list(self, download: bool = True) -> Tuple[List[Tuple[str, int]], List[int]]:
        def fix_path(path: str) -> str:
            # The file lists store POSIX-style paths with a leading slash, e.g.
            # "/a/airfield/00000001.jpg". Strip the slash and switch to the
            # platform separator so the path can be joined with images_dir.
            if not path.startswith("/"):
                return path

            path = path[1:]
            if os.sep == "/":
                return path

            return path.replace("/", os.sep)

        def process(line: str) -> Tuple[str, int]:
            image, idx = line.split()
            return path.join(self.images_dir, fix_path(image)), int(idx)

        file, md5 = self._FILE_LIST_META[self.split]
        file = path.join(self.root, file)
        if not self._check_integrity(file, md5, download):
            self.download_devkit()

        with open(file, "r") as fh:
            images = [process(line) for line in fh]

        _, targets = zip(*images)
        return images, list(targets)
    def download_devkit(self) -> None:
        file, md5 = self._DEVKIT_META["challenge" if self.split == "train-challenge" else "standard"]
        download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)

    def download_images(self) -> None:
        if path.exists(self.images_dir):
            raise RuntimeError(
                f"The directory {self.images_dir} already exists. If you want to re-download or re-extract the "
                f"images, delete the directory."
            )

        file, md5 = self._IMAGES_META[(self.split, self.small)]
        download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, extract_root=self.images_dir, md5=md5)

    def extra_repr(self) -> str:
        return "\n".join(("Split: {split}", "Small: {small}")).format(**self.__dict__)

    def _verify_split(self, split: str) -> str:
        return verify_str_arg(split, "split", self._SPLITS)

    def _check_integrity(self, file: str, md5: str, download: bool) -> bool:
        integrity = check_integrity(path.join(self.root, file), md5=md5)
        if not integrity and not download:
            raise RuntimeError(
                f"The file {file} does not exist or is corrupted. You can set download=True to download it."
            )

        return integrity
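
For reference, a minimal usage sketch of the new dataset class (illustration only, not part of the patch; /data/places365 is a hypothetical root directory):

import torchvision

# With download=True this fetches the devkit and, here, the small (256x256)
# validation images into the root directory, then serves
# (PIL.Image, class_index) pairs.
dataset = torchvision.datasets.Places365(
    "/data/places365", split="val", small=True, download=True
)
image, target = dataset[0]
print(image.size, target)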