Unverified Commit 4b0b332a authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Ported places365 dataset's tests to the new test framework (#3705)



* Ported places365 dataset's tests to the new test framework

* Made some attributes private

* Removed unnecessary compute_md5()

* Added test_images_download_preexisting()
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 4150ceae
...@@ -208,103 +208,3 @@ def widerface_root(): ...@@ -208,103 +208,3 @@ def widerface_root():
_make_annotations_archive(root_base) _make_annotations_archive(root_base)
yield root yield root
@contextlib.contextmanager
def places365_root(split="train-standard", small=False):
VARIANTS = {
"train-standard": "standard",
"train-challenge": "challenge",
"val": "standard",
}
# {split: file}
DEVKITS = {
"train-standard": "filelist_places365-standard.tar",
"train-challenge": "filelist_places365-challenge.tar",
"val": "filelist_places365-standard.tar",
}
CATEGORIES = "categories_places365.txt"
# {split: file}
FILE_LISTS = {
"train-standard": "places365_train_standard.txt",
"train-challenge": "places365_train_challenge.txt",
"val": "places365_train_standard.txt",
}
# {(split, small): (archive, folder_default, folder_renamed)}
IMAGES = {
("train-standard", False): ("train_large_places365standard.tar", "data_large", "data_large_standard"),
("train-challenge", False): ("train_large_places365challenge.tar", "data_large", "data_large_challenge"),
("val", False): ("val_large.tar", "val_large", "val_large"),
("train-standard", True): ("train_256_places365standard.tar", "data_256", "data_256_standard"),
("train-challenge", True): ("train_256_places365challenge.tar", "data_256", "data_256_challenge"),
("val", True): ("val_256.tar", "val_256", "val_256"),
}
# (class, idx)
CATEGORIES_CONTENT = (("/a/airfield", 0), ("/a/apartment_building/outdoor", 8), ("/b/badlands", 30))
# (file, idx)
FILE_LIST_CONTENT = (
("Places365_val_00000001.png", 0),
*((f"{category}/Places365_train_00000001.png", idx) for category, idx in CATEGORIES_CONTENT),
)
def mock_target(attr, partial="torchvision.datasets.places365.Places365"):
return f"{partial}.{attr}"
def make_txt(root, name, seq):
file = os.path.join(root, name)
with open(file, "w") as fh:
for string, idx in seq:
fh.write(f"{string} {idx}\n")
return name, compute_md5(file)
def make_categories_txt(root, name):
return make_txt(root, name, CATEGORIES_CONTENT)
def make_file_list_txt(root, name):
return make_txt(root, name, FILE_LIST_CONTENT)
def make_image(file, size):
os.makedirs(os.path.dirname(file), exist_ok=True)
PIL.Image.fromarray(np.zeros((*size, 3), dtype=np.uint8)).save(file)
def make_devkit_archive(stack, root, split):
archive = DEVKITS[split]
files = []
meta = make_categories_txt(root, CATEGORIES)
mock_class_attribute(stack, mock_target("_CATEGORIES_META"), meta)
files.append(meta[0])
meta = {split: make_file_list_txt(root, FILE_LISTS[split])}
mock_class_attribute(stack, mock_target("_FILE_LIST_META"), meta)
files.extend([item[0] for item in meta.values()])
meta = {VARIANTS[split]: make_tar(root, archive, *files)}
mock_class_attribute(stack, mock_target("_DEVKIT_META"), meta)
def make_images_archive(stack, root, split, small):
archive, folder_default, folder_renamed = IMAGES[(split, small)]
image_size = (256, 256) if small else (512, random.randint(512, 1024))
files, idcs = zip(*FILE_LIST_CONTENT)
images = [file.lstrip("/").replace("/", os.sep) for file in files]
for image in images:
make_image(os.path.join(root, folder_default, image), image_size)
meta = {(split, small): make_tar(root, archive, folder_default)}
mock_class_attribute(stack, mock_target("_IMAGES_META"), meta)
return [(os.path.join(root, folder_renamed, image), idx) for image, idx in zip(images, idcs)]
with contextlib.ExitStack() as stack, get_tmp_dir() as root:
make_devkit_archive(stack, root, split)
class_to_idx = dict(CATEGORIES_CONTENT)
classes = list(class_to_idx.keys())
data = {"class_to_idx": class_to_idx, "classes": classes}
data["imgs"] = make_images_archive(stack, root, split, small)
clean_dir(root, ".tar$")
yield root, data
...@@ -9,7 +9,6 @@ from torch._utils_internal import get_file_path_2 ...@@ -9,7 +9,6 @@ from torch._utils_internal import get_file_path_2
import torchvision import torchvision
from torchvision.datasets import utils from torchvision.datasets import utils
from common_utils import get_tmp_dir from common_utils import get_tmp_dir
from fakedata_generation import places365_root
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from urllib.request import Request, urlopen from urllib.request import Request, urlopen
import itertools import itertools
...@@ -41,106 +40,6 @@ except ImportError: ...@@ -41,106 +40,6 @@ except ImportError:
HAS_PYAV = False HAS_PYAV = False
class DatasetTestcase(unittest.TestCase):
def generic_classification_dataset_test(self, dataset, num_images=1):
self.assertEqual(len(dataset), num_images)
img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, int))
def generic_segmentation_dataset_test(self, dataset, num_images=1):
self.assertEqual(len(dataset), num_images)
img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, PIL.Image.Image))
class Tester(DatasetTestcase):
def test_places365(self):
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
with places365_root(split=split, small=small) as places365:
root, data = places365
dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)
self.generic_classification_dataset_test(dataset, num_images=len(data["imgs"]))
def test_places365_transforms(self):
expected_image = "image"
expected_target = "target"
def transform(image):
return expected_image
def target_transform(target):
return expected_target
with places365_root() as places365:
root, data = places365
dataset = torchvision.datasets.Places365(
root, transform=transform, target_transform=target_transform, download=True
)
actual_image, actual_target = dataset[0]
self.assertEqual(actual_image, expected_image)
self.assertEqual(actual_target, expected_target)
def test_places365_devkit_download(self):
for split in ("train-standard", "train-challenge", "val"):
with self.subTest(split=split):
with places365_root(split=split) as places365:
root, data = places365
dataset = torchvision.datasets.Places365(root, split=split, download=True)
with self.subTest("classes"):
self.assertSequenceEqual(dataset.classes, data["classes"])
with self.subTest("class_to_idx"):
self.assertDictEqual(dataset.class_to_idx, data["class_to_idx"])
with self.subTest("imgs"):
self.assertSequenceEqual(dataset.imgs, data["imgs"])
def test_places365_devkit_no_download(self):
for split in ("train-standard", "train-challenge", "val"):
with self.subTest(split=split):
with places365_root(split=split) as places365:
root, data = places365
with self.assertRaises(RuntimeError):
torchvision.datasets.Places365(root, split=split, download=False)
def test_places365_images_download(self):
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
with self.subTest(split=split, small=small):
with places365_root(split=split, small=small) as places365:
root, data = places365
dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)
assert all(os.path.exists(item[0]) for item in dataset.imgs)
def test_places365_images_download_preexisting(self):
split = "train-standard"
small = False
images_dir = "data_large_standard"
with places365_root(split=split, small=small) as places365:
root, data = places365
os.mkdir(os.path.join(root, images_dir))
with self.assertRaises(RuntimeError):
torchvision.datasets.Places365(root, split=split, small=small, download=True)
def test_places365_repr_smoke(self):
with places365_root() as places365:
root, data = places365
dataset = torchvision.datasets.Places365(root, download=True)
self.assertIsInstance(repr(dataset), str)
class STL10TestCase(datasets_utils.ImageDatasetTestCase): class STL10TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.STL10 DATASET_CLASS = datasets.STL10
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
...@@ -1763,5 +1662,96 @@ class SvhnTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1763,5 +1662,96 @@ class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
return num_examples return num_examples
class Places365TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Places365
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
split=("train-standard", "train-challenge", "val"),
small=(False, True),
)
_CATEGORIES = "categories_places365.txt"
# {split: file}
_FILE_LISTS = {
"train-standard": "places365_train_standard.txt",
"train-challenge": "places365_train_challenge.txt",
"val": "places365_val.txt",
}
# {(split, small): folder_name}
_IMAGES = {
("train-standard", False): "data_large_standard",
("train-challenge", False): "data_large_challenge",
("val", False): "val_large",
("train-standard", True): "data_256_standard",
("train-challenge", True): "data_256_challenge",
("val", True): "val_256",
}
# (class, idx)
_CATEGORIES_CONTENT = (
("/a/airfield", 0),
("/a/apartment_building/outdoor", 8),
("/b/badlands", 30),
)
# (file, idx)
_FILE_LIST_CONTENT = (
("Places365_val_00000001.png", 0),
*((f"{category}/Places365_train_00000001.png", idx)
for category, idx in _CATEGORIES_CONTENT),
)
@staticmethod
def _make_txt(root, name, seq):
file = os.path.join(root, name)
with open(file, "w") as fh:
for text, idx in seq:
fh.write(f"{text} {idx}\n")
@staticmethod
def _make_categories_txt(root, name):
Places365TestCase._make_txt(root, name, Places365TestCase._CATEGORIES_CONTENT)
@staticmethod
def _make_file_list_txt(root, name):
Places365TestCase._make_txt(root, name, Places365TestCase._FILE_LIST_CONTENT)
@staticmethod
def _make_image(file_name, size):
os.makedirs(os.path.dirname(file_name), exist_ok=True)
PIL.Image.fromarray(np.zeros((*size, 3), dtype=np.uint8)).save(file_name)
@staticmethod
def _make_devkit_archive(root, split):
Places365TestCase._make_categories_txt(root, Places365TestCase._CATEGORIES)
Places365TestCase._make_file_list_txt(root, Places365TestCase._FILE_LISTS[split])
@staticmethod
def _make_images_archive(root, split, small):
folder_name = Places365TestCase._IMAGES[(split, small)]
image_size = (256, 256) if small else (512, random.randint(512, 1024))
files, idcs = zip(*Places365TestCase._FILE_LIST_CONTENT)
images = [f.lstrip("/").replace("/", os.sep) for f in files]
for image in images:
Places365TestCase._make_image(os.path.join(root, folder_name, image), image_size)
return [(os.path.join(root, folder_name, image), idx) for image, idx in zip(images, idcs)]
def inject_fake_data(self, tmpdir, config):
self._make_devkit_archive(tmpdir, config['split'])
return len(self._make_images_archive(tmpdir, config['split'], config['small']))
def test_classes(self):
classes = list(map(lambda x: x[0], self._CATEGORIES_CONTENT))
with self.create_dataset() as (dataset, _):
self.assertEqual(dataset.classes, classes)
def test_class_to_idx(self):
class_to_idx = dict(self._CATEGORIES_CONTENT)
with self.create_dataset() as (dataset, _):
self.assertEqual(dataset.class_to_idx, class_to_idx)
def test_images_download_preexisting(self):
with self.assertRaises(RuntimeError):
with self.create_dataset({'download': True}):
pass
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment