Unverified Commit 6f028212 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Fix Places365 dataset (#2625)

* fix images extraction

* remove test split

* fix tests

* be less clever in test data generation

* remove micro optimization

* lint
parent be8192e2
......@@ -12,6 +12,7 @@ from itertools import cycle
from torchvision.io.video import write_video
import unittest.mock
import hashlib
from distutils import dir_util
@contextlib.contextmanager
......@@ -318,101 +319,127 @@ def ucf101_root():
@contextlib.contextmanager
def places365_root(split="train-standard", small=False, extract_images=True):
CATEGORIES = (("/a/airfield", 0), ("/a/apartment_building/outdoor", 8), ("/b/badlands", 30))
FILE_LIST = [(f"{idx}.png", idx) for idx in tuple(zip(*CATEGORIES))[1]]
VARIANTS = {
"train-standard": "standard",
"train-challenge": "challenge",
"val": "standard",
}
# {split: file}
DEVKITS = {
"train-standard": "filelist_places365-standard.tar",
"train-challenge": "filelist_places365-challenge.tar",
"val": "filelist_places365-standard.tar",
}
CATEGORIES = "categories_places365.txt"
# {split: file}
FILE_LISTS = {
"train-standard": "places365_train_standard.txt",
"train-challenge": "places365_train_challenge.txt",
"val": "places365_train_standard.txt",
}
# {(split, small): (archive, folder_default, folder_renamed)}
IMAGES = {
("train-standard", False): ("train_large_places365standard.tar", "data_large", "data_large_standard"),
("train-challenge", False): ("train_large_places365challenge.tar", "data_large", "data_large_challenge"),
("val", False): ("val_large.tar", "val_large", "val_large"),
("train-standard", True): ("train_256_places365standard.tar", "data_256", "data_256_standard"),
("train-challenge", True): ("train_256_places365challenge.tar", "data_256", "data_256_challenge"),
("val", True): ("val_256.tar", "val_256", "val_256"),
}
# (class, idx)
CATEGORIES_CONTENT = (("/a/airfield", 0), ("/a/apartment_building/outdoor", 8), ("/b/badlands", 30))
# (file, idx)
FILE_LIST_CONTENT = (
("Places365_val_00000001.png", 0),
*((f"{category}/Places365_train_00000001.png", idx) for category, idx in CATEGORIES_CONTENT),
)
def mock_target(attr, partial="torchvision.datasets.places365.Places365"):
return f"{partial}.{attr}"
def mock_class_attribute(stack, attr, new):
mock = unittest.mock.patch(mock_target(attr), new_callable=unittest.mock.PropertyMock, return_value=new)
stack.enter_context(mock)
return mock
def compute_md5(file):
with open(file, "rb") as fh:
return hashlib.md5(fh.read()).hexdigest()
def make_txt(root, name, cls_or_image_seq):
def make_txt(root, name, seq):
file = os.path.join(root, name)
with open(file, "w") as fh:
for cls_or_image, idx in cls_or_image_seq:
fh.write(f"{cls_or_image} {idx}\n")
for string, idx in seq:
fh.write(f"{string} {idx}\n")
return name, compute_md5(file)
def make_categories_txt(root, name):
return make_txt(root, name, CATEGORIES)
return make_txt(root, name, CATEGORIES_CONTENT)
def make_file_list_txt(root, name):
return make_txt(root, name, FILE_LIST)
return make_txt(root, name, FILE_LIST_CONTENT)
def make_image(root, name, size):
PIL.Image.fromarray(np.zeros((*size, 3), dtype=np.uint8)).save(os.path.join(root, name))
def make_image(file, size):
os.makedirs(os.path.dirname(file), exist_ok=True)
PIL.Image.fromarray(np.zeros((*size, 3), dtype=np.uint8)).save(file)
def make_tar(root, name, files, remove_files=True):
def make_tar(root, name, *files, remove_files=True):
name = f"{os.path.splitext(name)[0]}.tar"
archive = os.path.join(root, name)
files = [os.path.join(root, file) for file in files]
with tarfile.open(archive, "w") as fh:
for file in files:
fh.add(file, os.path.basename(file))
fh.add(os.path.join(root, file), arcname=file)
if remove_files:
for file in files:
for file in [os.path.join(root, file) for file in files]:
if os.path.isdir(file):
dir_util.remove_tree(file)
else:
os.remove(file)
return name, compute_md5(archive)
def mock_target(attr, partial="torchvision.datasets.places365.Places365"):
return f"{partial}.{attr}"
def mock_class_attribute(stack, attr, new):
mock = unittest.mock.patch(mock_target(attr), new_callable=unittest.mock.PropertyMock, return_value=new)
stack.enter_context(mock)
return mock
def split_to_variant(split):
return "challenge" if split == "train-challenge" else "standard"
def make_devkit_archive(stack, root, split):
variant = split_to_variant(split)
archive = f"filelist_places365-{variant}.tar"
archive = DEVKITS[split]
files = []
meta = make_categories_txt(root, "categories_places365.txt")
meta = make_categories_txt(root, CATEGORIES)
mock_class_attribute(stack, "_CATEGORIES_META", meta)
files.append(meta[0])
meta = {
split: make_file_list_txt(root, f"places365_{split.replace('-', '_')}.txt")
for split in (f"train-{variant}", "val", "test")
}
meta = {split: make_file_list_txt(root, FILE_LISTS[split])}
mock_class_attribute(stack, "_FILE_LIST_META", meta)
files.extend([item[0] for item in meta.values()])
meta = {variant: make_tar(root, archive, files)}
meta = {VARIANTS[split]: make_tar(root, archive, *files)}
mock_class_attribute(stack, "_DEVKIT_META", meta)
def make_images_archive(stack, root, split, small):
if split.startswith("train"):
images_dir = f"train_{'256' if small else 'large'}_places365{split_to_variant(split)}"
else:
images_dir = f"{split}_{'256' if small else 'large'}"
archive = f"{images_dir}.tar"
archive, folder_default, folder_renamed = IMAGES[(split, small)]
size = (256, 256) if small else (512, random.randint(512, 1024))
imgs = [item[0] for item in FILE_LIST]
for img in imgs:
make_image(root, img, size)
image_size = (256, 256) if small else (512, random.randint(512, 1024))
files, idcs = zip(*FILE_LIST_CONTENT)
images = [file.lstrip("/").replace("/", os.sep) for file in files]
for image in images:
make_image(os.path.join(root, folder_default, image), image_size)
meta = {(split, small): make_tar(root, archive, imgs)}
meta = {(split, small): make_tar(root, archive, folder_default)}
mock_class_attribute(stack, "_IMAGES_META", meta)
return images_dir
return [(os.path.join(root, folder_renamed, image), idx) for image, idx in zip(images, idcs)]
with contextlib.ExitStack() as stack:
with get_tmp_dir() as root:
with contextlib.ExitStack() as stack, get_tmp_dir() as root:
make_devkit_archive(stack, root, split)
class_to_idx = dict(CATEGORIES)
class_to_idx = dict(CATEGORIES_CONTENT)
classes = list(class_to_idx.keys())
data = {"class_to_idx": class_to_idx, "classes": classes}
if extract_images:
images_dir = make_images_archive(stack, root, split, small)
data["imgs"] = [(os.path.join(root, images_dir, file), idx) for file, idx in FILE_LIST]
data["imgs"] = make_images_archive(stack, root, split, small)
else:
stack.enter_context(unittest.mock.patch(mock_target("download_images")))
data["imgs"] = None
yield root, data
......@@ -283,7 +283,7 @@ class Tester(unittest.TestCase):
self.assertEqual(label, 1)
def test_places365(self):
for split, small in itertools.product(("train-standard", "train-challenge", "val", "test"), (False, True)):
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
with places365_root(split=split, small=small) as places365:
root, data = places365
......@@ -313,7 +313,7 @@ class Tester(unittest.TestCase):
@mock.patch("torchvision.datasets.utils.download_url")
def test_places365_downloadable(self, download_url):
for split, small in itertools.product(("train-standard", "train-challenge", "val", "test"), (False, True)):
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
with places365_root(split=split, small=small) as places365:
root, data = places365
......@@ -326,7 +326,7 @@ class Tester(unittest.TestCase):
assert response.code == 200, f"Server returned status code {response.code} for {url}."
def test_places365_devkit_download(self):
for split in ("train-standard", "train-challenge", "val", "test"):
for split in ("train-standard", "train-challenge", "val"):
with self.subTest(split=split):
with places365_root(split=split) as places365:
root, data = places365
......@@ -343,7 +343,7 @@ class Tester(unittest.TestCase):
self.assertSequenceEqual(dataset.imgs, data["imgs"])
def test_places365_devkit_no_download(self):
for split in ("train-standard", "train-challenge", "val", "test"):
for split in ("train-standard", "train-challenge", "val"):
with self.subTest(split=split):
with places365_root(split=split, extract_images=False) as places365:
root, data = places365
......@@ -352,7 +352,7 @@ class Tester(unittest.TestCase):
torchvision.datasets.Places365(root, split=split, download=False)
def test_places365_images_download(self):
for split, small in itertools.product(("train-standard", "train-challenge", "val", "test"), (False, True)):
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
with self.subTest(split=split, small=small):
with places365_root(split=split, small=small) as places365:
root, data = places365
......@@ -364,7 +364,7 @@ class Tester(unittest.TestCase):
def test_places365_images_download_preexisting(self):
split = "train-standard"
small = False
images_dir = "train_large_places365standard"
images_dir = "data_large_standard"
with places365_root(split=split, small=small) as places365:
root, data = places365
......
......@@ -14,7 +14,7 @@ class Places365(VisionDataset):
Args:
root (string): Root directory of the Places365 dataset.
split (string, optional): The dataset split. Can be one of ``train-standard`` (default), ``train-challenge``,
``val``, and ``test``.
``val``.
small (bool, optional): If ``True``, uses the small images, i.e. resized to 256 x 256 pixels, instead of the
high resolution ones.
download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
......@@ -35,7 +35,7 @@ class Places365(VisionDataset):
RuntimeError: If ``download is False`` and the meta files, i.e. the devkit, are not present or corrupted.
RuntimeError: If ``download is True`` and the image archive is already extracted.
"""
_SPLITS = ("train-standard", "train-challenge", "val", "test")
_SPLITS = ("train-standard", "train-challenge", "val")
_BASE_URL = "http://data.csail.mit.edu/places/places365/"
# {variant: (archive, md5)}
_DEVKIT_META = {
......@@ -49,18 +49,15 @@ class Places365(VisionDataset):
"train-standard": ("places365_train_standard.txt", "30f37515461640559006b8329efbed1a"),
"train-challenge": ("places365_train_challenge.txt", "b2931dc997b8c33c27e7329c073a6b57"),
"val": ("places365_val.txt", "e9f2fd57bfd9d07630173f4e8708e4b1"),
"test": ("places365_test.txt", "2fce8233fe493576d724142e45d93653"),
}
# {(split, small): (file, md5)}
_IMAGES_META = {
("train-standard", False): ("train_large_places365standard.tar", "67e186b496a84c929568076ed01a8aa1"),
("train-challenge", False): ("train_large_places365challenge.tar", "605f18e68e510c82b958664ea134545f"),
("val", False): ("val_large.tar", "9b71c4993ad89d2d8bcbdc4aef38042f"),
("test", False): ("test_large.tar", "41a4b6b724b1d2cd862fb3871ed59913"),
("train-standard", True): ("train_256_places365standard.tar", "53ca1c756c3d1e7809517cc47c5561c5"),
("train-challenge", True): ("train_256_places365challenge.tar", "741915038a5e3471ec7332404dfb64ef"),
("val", True): ("val_256.tar", "e27b17d8d44f4af9a78502beb927f808"),
("test", True): ("test_256.tar", "f532f6ad7b582262a2ec8009075e186b"),
}
def __init__(
......@@ -97,10 +94,18 @@ class Places365(VisionDataset):
def __len__(self) -> int:
return len(self.imgs)
@property
def variant(self) -> str:
return "challenge" if "challenge" in self.split else "standard"
@property
def images_dir(self) -> str:
file, _ = self._IMAGES_META[(self.split, self.small)]
return path.join(self.root, path.splitext(file)[0])
size = "256" if self.small else "large"
if self.split.startswith("train"):
dir = f"data_{size}_{self.variant}"
else:
dir = f"{self.split}_{size}"
return path.join(self.root, dir)
def load_categories(self, download: bool = True) -> Tuple[List[str], Dict[str, int]]:
def process(line: str) -> Tuple[str, int]:
......@@ -118,20 +123,9 @@ class Places365(VisionDataset):
return sorted(class_to_idx.keys()), class_to_idx
def load_file_list(self, download: bool = True) -> Tuple[List[Tuple[str, int]], List[int]]:
def fix_path(path: str) -> str:
if not path.startswith("/"):
return path
path = path[1:]
if os.sep == "/":
return path
return path.replace("/", os.sep)
def process(line: str) -> Tuple[str, int]:
def process(line: str, sep="/") -> Tuple[str, int]:
image, idx = line.split()
return path.join(self.images_dir, fix_path(image)), int(idx)
return path.join(self.images_dir, image.lstrip(sep).replace(sep, os.sep)), int(idx)
file, md5 = self._FILE_LIST_META[self.split]
file = path.join(self.root, file)
......@@ -145,7 +139,7 @@ class Places365(VisionDataset):
return images, list(targets)
def download_devkit(self) -> None:
file, md5 = self._DEVKIT_META["challenge" if self.split == "train-challenge" else "standard"]
file, md5 = self._DEVKIT_META[self.variant]
download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)
def download_images(self) -> None:
......@@ -156,7 +150,10 @@ class Places365(VisionDataset):
)
file, md5 = self._IMAGES_META[(self.split, self.small)]
download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, extract_root=self.images_dir, md5=md5)
download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)
if self.split.startswith("train"):
os.rename(self.images_dir.rsplit("_", 1)[0], self.images_dir)
def extra_repr(self) -> str:
return "\n".join(("Split: {split}", "Small: {small}")).format(**self.__dict__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment