Unverified Commit 4076d7bf authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add tests for LSUN (#3454)


Co-authored-by: default avatarvfdev <vfdev.5@gmail.com>
Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
parent e987071c
......@@ -23,6 +23,8 @@ import torch
import shutil
import json
import random
import string
import io
try:
......@@ -954,5 +956,85 @@ class UCF101TestCase(datasets_utils.VideoDatasetTestCase):
fh.writelines(f"{file}\n" for file in sorted(video_files))
class LSUNTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.LSUN
REQUIRED_PACKAGES = ("lmdb",)
CONFIGS = datasets_utils.combinations_grid(
classes=("train", "test", "val", ["bedroom_train", "church_outdoor_train"])
)
_CATEGORIES = (
"bedroom",
"bridge",
"church_outdoor",
"classroom",
"conference_room",
"dining_room",
"kitchen",
"living_room",
"restaurant",
"tower",
)
def inject_fake_data(self, tmpdir, config):
root = pathlib.Path(tmpdir)
num_images = 0
for cls in self._parse_classes(config["classes"]):
num_images += self._create_lmdb(root, cls)
return num_images
@contextlib.contextmanager
def create_dataset(
self,
*args, **kwargs
):
with super().create_dataset(*args, **kwargs) as output:
yield output
# Currently datasets.LSUN caches the keys in the current directory rather than in the root directory. Thus,
# this creates a number of unique _cache_* files in the current directory that will not be removed together
# with the temporary directory
for file in os.listdir(os.getcwd()):
if file.startswith("_cache_"):
os.remove(file)
def _parse_classes(self, classes):
if not isinstance(classes, str):
return classes
split = classes
if split == "test":
return [split]
return [f"{category}_{split}" for category in self._CATEGORIES]
def _create_lmdb(self, root, cls):
lmdb = datasets_utils.lazy_importer.lmdb
hexdigits_lowercase = string.digits + string.ascii_lowercase[:6]
folder = f"{cls}_lmdb"
num_images = torch.randint(1, 4, size=()).item()
format = "webp"
files = datasets_utils.create_image_folder(root, folder, lambda idx: f"{idx}.{format}", num_images)
with lmdb.open(str(root / folder)) as env, env.begin(write=True) as txn:
for file in files:
key = "".join(random.choice(hexdigits_lowercase) for _ in range(40)).encode()
buffer = io.BytesIO()
Image.open(file).save(buffer, format)
buffer.seek(0)
value = buffer.read()
txn.put(key, value)
os.remove(file)
return num_images
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment