Unverified Commit 6f028212 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Fix Places365 dataset (#2625)

* fix images extraction

* remove test split

* fix tests

* be less clever in test data generation

* remove micro optimization

* lint
parent be8192e2
...@@ -12,6 +12,7 @@ from itertools import cycle ...@@ -12,6 +12,7 @@ from itertools import cycle
from torchvision.io.video import write_video from torchvision.io.video import write_video
import unittest.mock import unittest.mock
import hashlib import hashlib
from distutils import dir_util
@contextlib.contextmanager @contextlib.contextmanager
...@@ -318,101 +319,127 @@ def ucf101_root(): ...@@ -318,101 +319,127 @@ def ucf101_root():
@contextlib.contextmanager @contextlib.contextmanager
def places365_root(split="train-standard", small=False, extract_images=True): def places365_root(split="train-standard", small=False, extract_images=True):
CATEGORIES = (("/a/airfield", 0), ("/a/apartment_building/outdoor", 8), ("/b/badlands", 30)) VARIANTS = {
FILE_LIST = [(f"{idx}.png", idx) for idx in tuple(zip(*CATEGORIES))[1]] "train-standard": "standard",
"train-challenge": "challenge",
"val": "standard",
}
# {split: file}
DEVKITS = {
"train-standard": "filelist_places365-standard.tar",
"train-challenge": "filelist_places365-challenge.tar",
"val": "filelist_places365-standard.tar",
}
CATEGORIES = "categories_places365.txt"
# {split: file}
FILE_LISTS = {
"train-standard": "places365_train_standard.txt",
"train-challenge": "places365_train_challenge.txt",
"val": "places365_train_standard.txt",
}
# {(split, small): (archive, folder_default, folder_renamed)}
IMAGES = {
("train-standard", False): ("train_large_places365standard.tar", "data_large", "data_large_standard"),
("train-challenge", False): ("train_large_places365challenge.tar", "data_large", "data_large_challenge"),
("val", False): ("val_large.tar", "val_large", "val_large"),
("train-standard", True): ("train_256_places365standard.tar", "data_256", "data_256_standard"),
("train-challenge", True): ("train_256_places365challenge.tar", "data_256", "data_256_challenge"),
("val", True): ("val_256.tar", "val_256", "val_256"),
}
# (class, idx)
CATEGORIES_CONTENT = (("/a/airfield", 0), ("/a/apartment_building/outdoor", 8), ("/b/badlands", 30))
# (file, idx)
FILE_LIST_CONTENT = (
("Places365_val_00000001.png", 0),
*((f"{category}/Places365_train_00000001.png", idx) for category, idx in CATEGORIES_CONTENT),
)
def mock_target(attr, partial="torchvision.datasets.places365.Places365"):
return f"{partial}.{attr}"
def mock_class_attribute(stack, attr, new):
mock = unittest.mock.patch(mock_target(attr), new_callable=unittest.mock.PropertyMock, return_value=new)
stack.enter_context(mock)
return mock
def compute_md5(file): def compute_md5(file):
with open(file, "rb") as fh: with open(file, "rb") as fh:
return hashlib.md5(fh.read()).hexdigest() return hashlib.md5(fh.read()).hexdigest()
def make_txt(root, name, cls_or_image_seq): def make_txt(root, name, seq):
file = os.path.join(root, name) file = os.path.join(root, name)
with open(file, "w") as fh: with open(file, "w") as fh:
for cls_or_image, idx in cls_or_image_seq: for string, idx in seq:
fh.write(f"{cls_or_image} {idx}\n") fh.write(f"{string} {idx}\n")
return name, compute_md5(file) return name, compute_md5(file)
def make_categories_txt(root, name): def make_categories_txt(root, name):
return make_txt(root, name, CATEGORIES) return make_txt(root, name, CATEGORIES_CONTENT)
def make_file_list_txt(root, name): def make_file_list_txt(root, name):
return make_txt(root, name, FILE_LIST) return make_txt(root, name, FILE_LIST_CONTENT)
def make_image(root, name, size): def make_image(file, size):
PIL.Image.fromarray(np.zeros((*size, 3), dtype=np.uint8)).save(os.path.join(root, name)) os.makedirs(os.path.dirname(file), exist_ok=True)
PIL.Image.fromarray(np.zeros((*size, 3), dtype=np.uint8)).save(file)
def make_tar(root, name, files, remove_files=True): def make_tar(root, name, *files, remove_files=True):
name = f"{os.path.splitext(name)[0]}.tar"
archive = os.path.join(root, name) archive = os.path.join(root, name)
files = [os.path.join(root, file) for file in files]
with tarfile.open(archive, "w") as fh: with tarfile.open(archive, "w") as fh:
for file in files: for file in files:
fh.add(file, os.path.basename(file)) fh.add(os.path.join(root, file), arcname=file)
if remove_files: if remove_files:
for file in files: for file in [os.path.join(root, file) for file in files]:
os.remove(file) if os.path.isdir(file):
dir_util.remove_tree(file)
else:
os.remove(file)
return name, compute_md5(archive) return name, compute_md5(archive)
def mock_target(attr, partial="torchvision.datasets.places365.Places365"):
return f"{partial}.{attr}"
def mock_class_attribute(stack, attr, new):
mock = unittest.mock.patch(mock_target(attr), new_callable=unittest.mock.PropertyMock, return_value=new)
stack.enter_context(mock)
return mock
def split_to_variant(split):
return "challenge" if split == "train-challenge" else "standard"
def make_devkit_archive(stack, root, split): def make_devkit_archive(stack, root, split):
variant = split_to_variant(split) archive = DEVKITS[split]
archive = f"filelist_places365-{variant}.tar"
files = [] files = []
meta = make_categories_txt(root, "categories_places365.txt") meta = make_categories_txt(root, CATEGORIES)
mock_class_attribute(stack, "_CATEGORIES_META", meta) mock_class_attribute(stack, "_CATEGORIES_META", meta)
files.append(meta[0]) files.append(meta[0])
meta = { meta = {split: make_file_list_txt(root, FILE_LISTS[split])}
split: make_file_list_txt(root, f"places365_{split.replace('-', '_')}.txt")
for split in (f"train-{variant}", "val", "test")
}
mock_class_attribute(stack, "_FILE_LIST_META", meta) mock_class_attribute(stack, "_FILE_LIST_META", meta)
files.extend([item[0] for item in meta.values()]) files.extend([item[0] for item in meta.values()])
meta = {variant: make_tar(root, archive, files)} meta = {VARIANTS[split]: make_tar(root, archive, *files)}
mock_class_attribute(stack, "_DEVKIT_META", meta) mock_class_attribute(stack, "_DEVKIT_META", meta)
def make_images_archive(stack, root, split, small): def make_images_archive(stack, root, split, small):
if split.startswith("train"): archive, folder_default, folder_renamed = IMAGES[(split, small)]
images_dir = f"train_{'256' if small else 'large'}_places365{split_to_variant(split)}"
else:
images_dir = f"{split}_{'256' if small else 'large'}"
archive = f"{images_dir}.tar"
size = (256, 256) if small else (512, random.randint(512, 1024)) image_size = (256, 256) if small else (512, random.randint(512, 1024))
imgs = [item[0] for item in FILE_LIST] files, idcs = zip(*FILE_LIST_CONTENT)
for img in imgs: images = [file.lstrip("/").replace("/", os.sep) for file in files]
make_image(root, img, size) for image in images:
make_image(os.path.join(root, folder_default, image), image_size)
meta = {(split, small): make_tar(root, archive, imgs)} meta = {(split, small): make_tar(root, archive, folder_default)}
mock_class_attribute(stack, "_IMAGES_META", meta) mock_class_attribute(stack, "_IMAGES_META", meta)
return images_dir return [(os.path.join(root, folder_renamed, image), idx) for image, idx in zip(images, idcs)]
with contextlib.ExitStack() as stack: with contextlib.ExitStack() as stack, get_tmp_dir() as root:
with get_tmp_dir() as root: make_devkit_archive(stack, root, split)
make_devkit_archive(stack, root, split) class_to_idx = dict(CATEGORIES_CONTENT)
class_to_idx = dict(CATEGORIES) classes = list(class_to_idx.keys())
classes = list(class_to_idx.keys()) data = {"class_to_idx": class_to_idx, "classes": classes}
data = {"class_to_idx": class_to_idx, "classes": classes}
if extract_images: if extract_images:
images_dir = make_images_archive(stack, root, split, small) data["imgs"] = make_images_archive(stack, root, split, small)
data["imgs"] = [(os.path.join(root, images_dir, file), idx) for file, idx in FILE_LIST] else:
else: stack.enter_context(unittest.mock.patch(mock_target("download_images")))
stack.enter_context(unittest.mock.patch(mock_target("download_images"))) data["imgs"] = None
yield root, data yield root, data
...@@ -283,7 +283,7 @@ class Tester(unittest.TestCase): ...@@ -283,7 +283,7 @@ class Tester(unittest.TestCase):
self.assertEqual(label, 1) self.assertEqual(label, 1)
def test_places365(self): def test_places365(self):
for split, small in itertools.product(("train-standard", "train-challenge", "val", "test"), (False, True)): for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
with places365_root(split=split, small=small) as places365: with places365_root(split=split, small=small) as places365:
root, data = places365 root, data = places365
...@@ -313,7 +313,7 @@ class Tester(unittest.TestCase): ...@@ -313,7 +313,7 @@ class Tester(unittest.TestCase):
@mock.patch("torchvision.datasets.utils.download_url") @mock.patch("torchvision.datasets.utils.download_url")
def test_places365_downloadable(self, download_url): def test_places365_downloadable(self, download_url):
for split, small in itertools.product(("train-standard", "train-challenge", "val", "test"), (False, True)): for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
with places365_root(split=split, small=small) as places365: with places365_root(split=split, small=small) as places365:
root, data = places365 root, data = places365
...@@ -326,7 +326,7 @@ class Tester(unittest.TestCase): ...@@ -326,7 +326,7 @@ class Tester(unittest.TestCase):
assert response.code == 200, f"Server returned status code {response.code} for {url}." assert response.code == 200, f"Server returned status code {response.code} for {url}."
def test_places365_devkit_download(self): def test_places365_devkit_download(self):
for split in ("train-standard", "train-challenge", "val", "test"): for split in ("train-standard", "train-challenge", "val"):
with self.subTest(split=split): with self.subTest(split=split):
with places365_root(split=split) as places365: with places365_root(split=split) as places365:
root, data = places365 root, data = places365
...@@ -343,7 +343,7 @@ class Tester(unittest.TestCase): ...@@ -343,7 +343,7 @@ class Tester(unittest.TestCase):
self.assertSequenceEqual(dataset.imgs, data["imgs"]) self.assertSequenceEqual(dataset.imgs, data["imgs"])
def test_places365_devkit_no_download(self): def test_places365_devkit_no_download(self):
for split in ("train-standard", "train-challenge", "val", "test"): for split in ("train-standard", "train-challenge", "val"):
with self.subTest(split=split): with self.subTest(split=split):
with places365_root(split=split, extract_images=False) as places365: with places365_root(split=split, extract_images=False) as places365:
root, data = places365 root, data = places365
...@@ -352,7 +352,7 @@ class Tester(unittest.TestCase): ...@@ -352,7 +352,7 @@ class Tester(unittest.TestCase):
torchvision.datasets.Places365(root, split=split, download=False) torchvision.datasets.Places365(root, split=split, download=False)
def test_places365_images_download(self): def test_places365_images_download(self):
for split, small in itertools.product(("train-standard", "train-challenge", "val", "test"), (False, True)): for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
with self.subTest(split=split, small=small): with self.subTest(split=split, small=small):
with places365_root(split=split, small=small) as places365: with places365_root(split=split, small=small) as places365:
root, data = places365 root, data = places365
...@@ -364,7 +364,7 @@ class Tester(unittest.TestCase): ...@@ -364,7 +364,7 @@ class Tester(unittest.TestCase):
def test_places365_images_download_preexisting(self): def test_places365_images_download_preexisting(self):
split = "train-standard" split = "train-standard"
small = False small = False
images_dir = "train_large_places365standard" images_dir = "data_large_standard"
with places365_root(split=split, small=small) as places365: with places365_root(split=split, small=small) as places365:
root, data = places365 root, data = places365
......
...@@ -14,7 +14,7 @@ class Places365(VisionDataset): ...@@ -14,7 +14,7 @@ class Places365(VisionDataset):
Args: Args:
root (string): Root directory of the Places365 dataset. root (string): Root directory of the Places365 dataset.
split (string, optional): The dataset split. Can be one of ``train-standard`` (default), ``train-challendge``, split (string, optional): The dataset split. Can be one of ``train-standard`` (default), ``train-challendge``,
``val``, and ``test``. ``val``.
small (bool, optional): If ``True``, uses the small images, i. e. resized to 256 x 256 pixels, instead of the small (bool, optional): If ``True``, uses the small images, i. e. resized to 256 x 256 pixels, instead of the
high resolution ones. high resolution ones.
download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
...@@ -35,7 +35,7 @@ class Places365(VisionDataset): ...@@ -35,7 +35,7 @@ class Places365(VisionDataset):
RuntimeError: If ``download is False`` and the meta files, i. e. the devkit, are not present or corrupted. RuntimeError: If ``download is False`` and the meta files, i. e. the devkit, are not present or corrupted.
RuntimeError: If ``download is True`` and the image archive is already extracted. RuntimeError: If ``download is True`` and the image archive is already extracted.
""" """
_SPLITS = ("train-standard", "train-challenge", "val", "test") _SPLITS = ("train-standard", "train-challenge", "val")
_BASE_URL = "http://data.csail.mit.edu/places/places365/" _BASE_URL = "http://data.csail.mit.edu/places/places365/"
# {variant: (archive, md5)} # {variant: (archive, md5)}
_DEVKIT_META = { _DEVKIT_META = {
...@@ -49,18 +49,15 @@ class Places365(VisionDataset): ...@@ -49,18 +49,15 @@ class Places365(VisionDataset):
"train-standard": ("places365_train_standard.txt", "30f37515461640559006b8329efbed1a"), "train-standard": ("places365_train_standard.txt", "30f37515461640559006b8329efbed1a"),
"train-challenge": ("places365_train_challenge.txt", "b2931dc997b8c33c27e7329c073a6b57"), "train-challenge": ("places365_train_challenge.txt", "b2931dc997b8c33c27e7329c073a6b57"),
"val": ("places365_val.txt", "e9f2fd57bfd9d07630173f4e8708e4b1"), "val": ("places365_val.txt", "e9f2fd57bfd9d07630173f4e8708e4b1"),
"test": ("places365_test.txt", "2fce8233fe493576d724142e45d93653"),
} }
# {(split, small): (file, md5)} # {(split, small): (file, md5)}
_IMAGES_META = { _IMAGES_META = {
("train-standard", False): ("train_large_places365standard.tar", "67e186b496a84c929568076ed01a8aa1"), ("train-standard", False): ("train_large_places365standard.tar", "67e186b496a84c929568076ed01a8aa1"),
("train-challenge", False): ("train_large_places365challenge.tar", "605f18e68e510c82b958664ea134545f"), ("train-challenge", False): ("train_large_places365challenge.tar", "605f18e68e510c82b958664ea134545f"),
("val", False): ("val_large.tar", "9b71c4993ad89d2d8bcbdc4aef38042f"), ("val", False): ("val_large.tar", "9b71c4993ad89d2d8bcbdc4aef38042f"),
("test", False): ("test_large.tar", "41a4b6b724b1d2cd862fb3871ed59913"),
("train-standard", True): ("train_256_places365standard.tar", "53ca1c756c3d1e7809517cc47c5561c5"), ("train-standard", True): ("train_256_places365standard.tar", "53ca1c756c3d1e7809517cc47c5561c5"),
("train-challenge", True): ("train_256_places365challenge.tar", "741915038a5e3471ec7332404dfb64ef"), ("train-challenge", True): ("train_256_places365challenge.tar", "741915038a5e3471ec7332404dfb64ef"),
("val", True): ("val_256.tar", "e27b17d8d44f4af9a78502beb927f808"), ("val", True): ("val_256.tar", "e27b17d8d44f4af9a78502beb927f808"),
("test", True): ("test_256.tar", "f532f6ad7b582262a2ec8009075e186b"),
} }
def __init__( def __init__(
...@@ -97,10 +94,18 @@ class Places365(VisionDataset): ...@@ -97,10 +94,18 @@ class Places365(VisionDataset):
def __len__(self) -> int: def __len__(self) -> int:
return len(self.imgs) return len(self.imgs)
@property
def variant(self) -> str:
return "challenge" if "challenge" in self.split else "standard"
@property @property
def images_dir(self) -> str: def images_dir(self) -> str:
file, _ = self._IMAGES_META[(self.split, self.small)] size = "256" if self.small else "large"
return path.join(self.root, path.splitext(file)[0]) if self.split.startswith("train"):
dir = f"data_{size}_{self.variant}"
else:
dir = f"{self.split}_{size}"
return path.join(self.root, dir)
def load_categories(self, download: bool = True) -> Tuple[List[str], Dict[str, int]]: def load_categories(self, download: bool = True) -> Tuple[List[str], Dict[str, int]]:
def process(line: str) -> Tuple[str, int]: def process(line: str) -> Tuple[str, int]:
...@@ -118,20 +123,9 @@ class Places365(VisionDataset): ...@@ -118,20 +123,9 @@ class Places365(VisionDataset):
return sorted(class_to_idx.keys()), class_to_idx return sorted(class_to_idx.keys()), class_to_idx
def load_file_list(self, download: bool = True) -> Tuple[List[Tuple[str, int]], List[int]]: def load_file_list(self, download: bool = True) -> Tuple[List[Tuple[str, int]], List[int]]:
def fix_path(path: str) -> str: def process(line: str, sep="/") -> Tuple[str, int]:
if not path.startswith("/"):
return path
path = path[1:]
if os.sep == "/":
return path
return path.replace("/", os.sep)
def process(line: str) -> Tuple[str, int]:
image, idx = line.split() image, idx = line.split()
return path.join(self.images_dir, fix_path(image)), int(idx) return path.join(self.images_dir, image.lstrip(sep).replace(sep, os.sep)), int(idx)
file, md5 = self._FILE_LIST_META[self.split] file, md5 = self._FILE_LIST_META[self.split]
file = path.join(self.root, file) file = path.join(self.root, file)
...@@ -145,7 +139,7 @@ class Places365(VisionDataset): ...@@ -145,7 +139,7 @@ class Places365(VisionDataset):
return images, list(targets) return images, list(targets)
def download_devkit(self) -> None: def download_devkit(self) -> None:
file, md5 = self._DEVKIT_META["challenge" if self.split == "train-challenge" else "standard"] file, md5 = self._DEVKIT_META[self.variant]
download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5) download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)
def download_images(self) -> None: def download_images(self) -> None:
...@@ -156,7 +150,10 @@ class Places365(VisionDataset): ...@@ -156,7 +150,10 @@ class Places365(VisionDataset):
) )
file, md5 = self._IMAGES_META[(self.split, self.small)] file, md5 = self._IMAGES_META[(self.split, self.small)]
download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, extract_root=self.images_dir, md5=md5) download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)
if self.split.startswith("train"):
os.rename(self.images_dir.rsplit("_", 1)[0], self.images_dir)
def extra_repr(self) -> str: def extra_repr(self) -> str:
return "\n".join(("Split: {split}", "Small: {small}")).format(**self.__dict__) return "\n".join(("Split: {split}", "Small: {small}")).format(**self.__dict__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment