Unverified Commit ae63bd08 authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Revert "Revert "Ported places365 dataset's tests to the new test framework...

Revert "Revert "Ported places365 dataset's tests to the new test framework (#3705)" (#3718)" (#3731)

This reverts commit d4195587.
parent 03f94a69
......@@ -208,103 +208,3 @@ def widerface_root():
_make_annotations_archive(root_base)
yield root
@contextlib.contextmanager
def places365_root(split="train-standard", small=False):
VARIANTS = {
"train-standard": "standard",
"train-challenge": "challenge",
"val": "standard",
}
# {split: file}
DEVKITS = {
"train-standard": "filelist_places365-standard.tar",
"train-challenge": "filelist_places365-challenge.tar",
"val": "filelist_places365-standard.tar",
}
CATEGORIES = "categories_places365.txt"
# {split: file}
FILE_LISTS = {
"train-standard": "places365_train_standard.txt",
"train-challenge": "places365_train_challenge.txt",
"val": "places365_train_standard.txt",
}
# {(split, small): (archive, folder_default, folder_renamed)}
IMAGES = {
("train-standard", False): ("train_large_places365standard.tar", "data_large", "data_large_standard"),
("train-challenge", False): ("train_large_places365challenge.tar", "data_large", "data_large_challenge"),
("val", False): ("val_large.tar", "val_large", "val_large"),
("train-standard", True): ("train_256_places365standard.tar", "data_256", "data_256_standard"),
("train-challenge", True): ("train_256_places365challenge.tar", "data_256", "data_256_challenge"),
("val", True): ("val_256.tar", "val_256", "val_256"),
}
# (class, idx)
CATEGORIES_CONTENT = (("/a/airfield", 0), ("/a/apartment_building/outdoor", 8), ("/b/badlands", 30))
# (file, idx)
FILE_LIST_CONTENT = (
("Places365_val_00000001.png", 0),
*((f"{category}/Places365_train_00000001.png", idx) for category, idx in CATEGORIES_CONTENT),
)
def mock_target(attr, partial="torchvision.datasets.places365.Places365"):
return f"{partial}.{attr}"
def make_txt(root, name, seq):
file = os.path.join(root, name)
with open(file, "w") as fh:
for string, idx in seq:
fh.write(f"{string} {idx}\n")
return name, compute_md5(file)
def make_categories_txt(root, name):
return make_txt(root, name, CATEGORIES_CONTENT)
def make_file_list_txt(root, name):
return make_txt(root, name, FILE_LIST_CONTENT)
def make_image(file, size):
os.makedirs(os.path.dirname(file), exist_ok=True)
PIL.Image.fromarray(np.zeros((*size, 3), dtype=np.uint8)).save(file)
def make_devkit_archive(stack, root, split):
archive = DEVKITS[split]
files = []
meta = make_categories_txt(root, CATEGORIES)
mock_class_attribute(stack, mock_target("_CATEGORIES_META"), meta)
files.append(meta[0])
meta = {split: make_file_list_txt(root, FILE_LISTS[split])}
mock_class_attribute(stack, mock_target("_FILE_LIST_META"), meta)
files.extend([item[0] for item in meta.values()])
meta = {VARIANTS[split]: make_tar(root, archive, *files)}
mock_class_attribute(stack, mock_target("_DEVKIT_META"), meta)
def make_images_archive(stack, root, split, small):
archive, folder_default, folder_renamed = IMAGES[(split, small)]
image_size = (256, 256) if small else (512, random.randint(512, 1024))
files, idcs = zip(*FILE_LIST_CONTENT)
images = [file.lstrip("/").replace("/", os.sep) for file in files]
for image in images:
make_image(os.path.join(root, folder_default, image), image_size)
meta = {(split, small): make_tar(root, archive, folder_default)}
mock_class_attribute(stack, mock_target("_IMAGES_META"), meta)
return [(os.path.join(root, folder_renamed, image), idx) for image, idx in zip(images, idcs)]
with contextlib.ExitStack() as stack, get_tmp_dir() as root:
make_devkit_archive(stack, root, split)
class_to_idx = dict(CATEGORIES_CONTENT)
classes = list(class_to_idx.keys())
data = {"class_to_idx": class_to_idx, "classes": classes}
data["imgs"] = make_images_archive(stack, root, split, small)
clean_dir(root, ".tar$")
yield root, data
......@@ -9,7 +9,6 @@ from torch._utils_internal import get_file_path_2
import torchvision
from torchvision.datasets import utils
from common_utils import get_tmp_dir
from fakedata_generation import places365_root
import xml.etree.ElementTree as ET
from urllib.request import Request, urlopen
import itertools
......@@ -41,106 +40,6 @@ except ImportError:
HAS_PYAV = False
class DatasetTestcase(unittest.TestCase):
def generic_classification_dataset_test(self, dataset, num_images=1):
self.assertEqual(len(dataset), num_images)
img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, int))
def generic_segmentation_dataset_test(self, dataset, num_images=1):
self.assertEqual(len(dataset), num_images)
img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, PIL.Image.Image))
class Tester(DatasetTestcase):
def test_places365(self):
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
with places365_root(split=split, small=small) as places365:
root, data = places365
dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)
self.generic_classification_dataset_test(dataset, num_images=len(data["imgs"]))
def test_places365_transforms(self):
expected_image = "image"
expected_target = "target"
def transform(image):
return expected_image
def target_transform(target):
return expected_target
with places365_root() as places365:
root, data = places365
dataset = torchvision.datasets.Places365(
root, transform=transform, target_transform=target_transform, download=True
)
actual_image, actual_target = dataset[0]
self.assertEqual(actual_image, expected_image)
self.assertEqual(actual_target, expected_target)
def test_places365_devkit_download(self):
for split in ("train-standard", "train-challenge", "val"):
with self.subTest(split=split):
with places365_root(split=split) as places365:
root, data = places365
dataset = torchvision.datasets.Places365(root, split=split, download=True)
with self.subTest("classes"):
self.assertSequenceEqual(dataset.classes, data["classes"])
with self.subTest("class_to_idx"):
self.assertDictEqual(dataset.class_to_idx, data["class_to_idx"])
with self.subTest("imgs"):
self.assertSequenceEqual(dataset.imgs, data["imgs"])
def test_places365_devkit_no_download(self):
for split in ("train-standard", "train-challenge", "val"):
with self.subTest(split=split):
with places365_root(split=split) as places365:
root, data = places365
with self.assertRaises(RuntimeError):
torchvision.datasets.Places365(root, split=split, download=False)
def test_places365_images_download(self):
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
with self.subTest(split=split, small=small):
with places365_root(split=split, small=small) as places365:
root, data = places365
dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)
assert all(os.path.exists(item[0]) for item in dataset.imgs)
def test_places365_images_download_preexisting(self):
split = "train-standard"
small = False
images_dir = "data_large_standard"
with places365_root(split=split, small=small) as places365:
root, data = places365
os.mkdir(os.path.join(root, images_dir))
with self.assertRaises(RuntimeError):
torchvision.datasets.Places365(root, split=split, small=small, download=True)
def test_places365_repr_smoke(self):
with places365_root() as places365:
root, data = places365
dataset = torchvision.datasets.Places365(root, download=True)
self.assertIsInstance(repr(dataset), str)
class STL10TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.STL10
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
......@@ -1763,5 +1662,96 @@ class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
return num_examples
class Places365TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Places365
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
split=("train-standard", "train-challenge", "val"),
small=(False, True),
)
_CATEGORIES = "categories_places365.txt"
# {split: file}
_FILE_LISTS = {
"train-standard": "places365_train_standard.txt",
"train-challenge": "places365_train_challenge.txt",
"val": "places365_val.txt",
}
# {(split, small): folder_name}
_IMAGES = {
("train-standard", False): "data_large_standard",
("train-challenge", False): "data_large_challenge",
("val", False): "val_large",
("train-standard", True): "data_256_standard",
("train-challenge", True): "data_256_challenge",
("val", True): "val_256",
}
# (class, idx)
_CATEGORIES_CONTENT = (
("/a/airfield", 0),
("/a/apartment_building/outdoor", 8),
("/b/badlands", 30),
)
# (file, idx)
_FILE_LIST_CONTENT = (
("Places365_val_00000001.png", 0),
*((f"{category}/Places365_train_00000001.png", idx)
for category, idx in _CATEGORIES_CONTENT),
)
@staticmethod
def _make_txt(root, name, seq):
file = os.path.join(root, name)
with open(file, "w") as fh:
for text, idx in seq:
fh.write(f"{text} {idx}\n")
@staticmethod
def _make_categories_txt(root, name):
Places365TestCase._make_txt(root, name, Places365TestCase._CATEGORIES_CONTENT)
@staticmethod
def _make_file_list_txt(root, name):
Places365TestCase._make_txt(root, name, Places365TestCase._FILE_LIST_CONTENT)
@staticmethod
def _make_image(file_name, size):
os.makedirs(os.path.dirname(file_name), exist_ok=True)
PIL.Image.fromarray(np.zeros((*size, 3), dtype=np.uint8)).save(file_name)
@staticmethod
def _make_devkit_archive(root, split):
Places365TestCase._make_categories_txt(root, Places365TestCase._CATEGORIES)
Places365TestCase._make_file_list_txt(root, Places365TestCase._FILE_LISTS[split])
@staticmethod
def _make_images_archive(root, split, small):
folder_name = Places365TestCase._IMAGES[(split, small)]
image_size = (256, 256) if small else (512, random.randint(512, 1024))
files, idcs = zip(*Places365TestCase._FILE_LIST_CONTENT)
images = [f.lstrip("/").replace("/", os.sep) for f in files]
for image in images:
Places365TestCase._make_image(os.path.join(root, folder_name, image), image_size)
return [(os.path.join(root, folder_name, image), idx) for image, idx in zip(images, idcs)]
def inject_fake_data(self, tmpdir, config):
self._make_devkit_archive(tmpdir, config['split'])
return len(self._make_images_archive(tmpdir, config['split'], config['small']))
def test_classes(self):
classes = list(map(lambda x: x[0], self._CATEGORIES_CONTENT))
with self.create_dataset() as (dataset, _):
self.assertEqual(dataset.classes, classes)
def test_class_to_idx(self):
class_to_idx = dict(self._CATEGORIES_CONTENT)
with self.create_dataset() as (dataset, _):
self.assertEqual(dataset.class_to_idx, class_to_idx)
def test_images_download_preexisting(self):
with self.assertRaises(RuntimeError):
with self.create_dataset({'download': True}):
pass
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment