import contextlib
import itertools
import os
import pathlib
import pickle
import sys
import unittest
import xml.etree.ElementTree as ET
from unittest import mock

import numpy as np
import PIL
from PIL import Image  # noqa: F401 -- imports the submodule so PIL.Image.Image is accessible
import torch
from torch._utils_internal import get_file_path_2

import torchvision
from torchvision import datasets
from torchvision.datasets import utils

import datasets_utils
from common_utils import get_tmp_dir
from fakedata_generation import (
    cifar_root, cityscapes_root, imagenet_root, mnist_root, places365_root,
    stl10_root, svhn_root, ucf101_root, voc_root, widerface_root,
)
try:
import scipy
HAS_SCIPY = True
except ImportError:
HAS_SCIPY = False
try:
import av
HAS_PYAV = True
except ImportError:
HAS_PYAV = False
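# Shared smoke tests: classification datasets are expected to yield
# (PIL image, int label) pairs, segmentation datasets (PIL image, PIL mask) pairs.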
class DatasetTestcase(unittest.TestCase):
def generic_classification_dataset_test(self, dataset, num_images=1):
self.assertEqual(len(dataset), num_images)
img, target = dataset[0]
        self.assertIsInstance(img, PIL.Image.Image)
        self.assertIsInstance(target, int)
def generic_segmentation_dataset_test(self, dataset, num_images=1):
self.assertEqual(len(dataset), num_images)
img, target = dataset[0]
        self.assertIsInstance(img, PIL.Image.Image)
        self.assertIsInstance(target, PIL.Image.Image)
class Tester(DatasetTestcase):
def test_imagefolder(self):
# TODO: create the fake data on-the-fly
FAKEDATA_DIR = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), 'assets', 'fakedata')
with get_tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root:
classes = sorted(['a', 'b'])
class_a_image_files = [
os.path.join(root, 'a', file) for file in ('a1.png', 'a2.png', 'a3.png')
]
class_b_image_files = [
os.path.join(root, 'b', file) for file in ('b1.png', 'b2.png', 'b3.png', 'b4.png')
]
dataset = torchvision.datasets.ImageFolder(root, loader=lambda x: x)
# test if all classes are present
self.assertEqual(classes, sorted(dataset.classes))
            # test if the combination of classes and class_to_idx functions correctly
for cls in classes:
self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]])
# test if all images were detected correctly
class_a_idx = dataset.class_to_idx['a']
class_b_idx = dataset.class_to_idx['b']
imgs_a = [(img_file, class_a_idx) for img_file in class_a_image_files]
imgs_b = [(img_file, class_b_idx) for img_file in class_b_image_files]
imgs = sorted(imgs_a + imgs_b)
self.assertEqual(imgs, dataset.imgs)
            # test if the dataset outputs all images correctly
outputs = sorted([dataset[i] for i in range(len(dataset))])
self.assertEqual(imgs, outputs)
# redo all tests with specified valid image files
dataset = torchvision.datasets.ImageFolder(
root, loader=lambda x: x, is_valid_file=lambda x: '3' in x)
self.assertEqual(classes, sorted(dataset.classes))
class_a_idx = dataset.class_to_idx['a']
class_b_idx = dataset.class_to_idx['b']
imgs_a = [(img_file, class_a_idx) for img_file in class_a_image_files
if '3' in img_file]
imgs_b = [(img_file, class_b_idx) for img_file in class_b_image_files
if '3' in img_file]
imgs = sorted(imgs_a + imgs_b)
self.assertEqual(imgs, dataset.imgs)
outputs = sorted([dataset[i] for i in range(len(dataset))])
self.assertEqual(imgs, outputs)
def test_imagefolder_empty(self):
with get_tmp_dir() as root:
with self.assertRaises(RuntimeError):
torchvision.datasets.ImageFolder(root, loader=lambda x: x)
with self.assertRaises(RuntimeError):
torchvision.datasets.ImageFolder(
root, loader=lambda x: x, is_valid_file=lambda x: False
)
@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
def test_mnist(self, mock_download_extract):
num_examples = 30
with mnist_root(num_examples, "MNIST") as root:
dataset = torchvision.datasets.MNIST(root, download=True)
self.generic_classification_dataset_test(dataset, num_images=num_examples)
img, target = dataset[0]
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
def test_kmnist(self, mock_download_extract):
num_examples = 30
with mnist_root(num_examples, "KMNIST") as root:
dataset = torchvision.datasets.KMNIST(root, download=True)
self.generic_classification_dataset_test(dataset, num_images=num_examples)
img, target = dataset[0]
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
def test_fashionmnist(self, mock_download_extract):
num_examples = 30
with mnist_root(num_examples, "FashionMNIST") as root:
dataset = torchvision.datasets.FashionMNIST(root, download=True)
self.generic_classification_dataset_test(dataset, num_images=num_examples)
img, target = dataset[0]
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
@mock.patch('torchvision.datasets.imagenet._verify_archive')
@unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
def test_imagenet(self, mock_verify):
with imagenet_root() as root:
dataset = torchvision.datasets.ImageNet(root, split='train')
self.generic_classification_dataset_test(dataset)
dataset = torchvision.datasets.ImageNet(root, split='val')
self.generic_classification_dataset_test(dataset)
@mock.patch('torchvision.datasets.WIDERFace._check_integrity')
    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
def test_widerface(self, mock_check_integrity):
mock_check_integrity.return_value = True
with widerface_root() as root:
            for split in ('train', 'val', 'test'):
                dataset = torchvision.datasets.WIDERFace(root, split=split)
                self.assertEqual(len(dataset), 1)
                img, target = dataset[0]
                self.assertIsInstance(img, PIL.Image.Image)
@mock.patch('torchvision.datasets.cifar.check_integrity')
@mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
def test_cifar10(self, mock_ext_check, mock_int_check):
mock_ext_check.return_value = True
mock_int_check.return_value = True
with cifar_root('CIFAR10') as root:
dataset = torchvision.datasets.CIFAR10(root, train=True, download=True)
self.generic_classification_dataset_test(dataset, num_images=5)
img, target = dataset[0]
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
dataset = torchvision.datasets.CIFAR10(root, train=False, download=True)
self.generic_classification_dataset_test(dataset)
img, target = dataset[0]
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
@mock.patch('torchvision.datasets.cifar.check_integrity')
@mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
def test_cifar100(self, mock_ext_check, mock_int_check):
mock_ext_check.return_value = True
mock_int_check.return_value = True
with cifar_root('CIFAR100') as root:
dataset = torchvision.datasets.CIFAR100(root, train=True, download=True)
self.generic_classification_dataset_test(dataset)
img, target = dataset[0]
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
dataset = torchvision.datasets.CIFAR100(root, train=False, download=True)
self.generic_classification_dataset_test(dataset)
img, target = dataset[0]
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
def test_cityscapes(self):
with cityscapes_root() as root:
for mode in ['coarse', 'fine']:
if mode == 'coarse':
splits = ['train', 'train_extra', 'val']
else:
splits = ['train', 'val', 'test']
for split in splits:
for target_type in ['semantic', 'instance']:
dataset = torchvision.datasets.Cityscapes(
root, split=split, target_type=target_type, mode=mode)
self.generic_segmentation_dataset_test(dataset, num_images=2)
color_dataset = torchvision.datasets.Cityscapes(
root, split=split, target_type='color', mode=mode)
color_img, color_target = color_dataset[0]
                    self.assertIsInstance(color_img, PIL.Image.Image)
                    self.assertEqual(np.array(color_target).shape[2], 4)
polygon_dataset = torchvision.datasets.Cityscapes(
root, split=split, target_type='polygon', mode=mode)
polygon_img, polygon_target = polygon_dataset[0]
                    self.assertIsInstance(polygon_img, PIL.Image.Image)
                    self.assertIsInstance(polygon_target, dict)
                    self.assertIsInstance(polygon_target['imgHeight'], int)
                    self.assertIsInstance(polygon_target['objects'], list)
# Test multiple target types
targets_combo = ['semantic', 'polygon', 'color']
multiple_types_dataset = torchvision.datasets.Cityscapes(
root, split=split, target_type=targets_combo, mode=mode)
output = multiple_types_dataset[0]
                    self.assertIsInstance(output, tuple)
                    self.assertEqual(len(output), 2)
                    self.assertIsInstance(output[0], PIL.Image.Image)
                    self.assertIsInstance(output[1], tuple)
                    self.assertEqual(len(output[1]), 3)
                    self.assertIsInstance(output[1][0], PIL.Image.Image)  # semantic
                    self.assertIsInstance(output[1][1], dict)  # polygon
                    self.assertIsInstance(output[1][2], PIL.Image.Image)  # color
@mock.patch('torchvision.datasets.SVHN._check_integrity')
@unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
def test_svhn(self, mock_check):
mock_check.return_value = True
with svhn_root() as root:
dataset = torchvision.datasets.SVHN(root, split="train")
self.generic_classification_dataset_test(dataset, num_images=2)
dataset = torchvision.datasets.SVHN(root, split="test")
self.generic_classification_dataset_test(dataset, num_images=2)
dataset = torchvision.datasets.SVHN(root, split="extra")
self.generic_classification_dataset_test(dataset, num_images=2)
@mock.patch('torchvision.datasets.voc.download_extract')
def test_voc_parse_xml(self, mock_download_extract):
with voc_root() as root:
dataset = torchvision.datasets.VOCDetection(root)
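            # parse_voc_xml collates repeated tags, so both <object> entries of the
            # second fixture must end up in a single list under the "object" key.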
            single_object_xml = "<annotation><object><name>cat</name></object></annotation>"
            multiple_object_xml = (
                "<annotation>"
                "<object><name>cat</name></object>"
                "<object><name>dog</name></object>"
                "</annotation>"
            )
single_object_parsed = dataset.parse_voc_xml(ET.fromstring(single_object_xml))
multiple_object_parsed = dataset.parse_voc_xml(ET.fromstring(multiple_object_xml))
self.assertEqual(single_object_parsed, {'annotation': {'object': [{'name': 'cat'}]}})
self.assertEqual(multiple_object_parsed,
{'annotation': {
'object': [{
'name': 'cat'
}, {
'name': 'dog'
}]
}})
@unittest.skipIf(not HAS_PYAV, "PyAV unavailable")
def test_ucf101(self):
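        # Computing video metadata is the slow part, so the metadata produced by the
        # first instantiation is reused for all later ones via _precomputed_metadata.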
cached_meta_data = None
with ucf101_root() as (root, ann_root):
for split in {True, False}:
for fold in range(1, 4):
for length in {10, 15, 20}:
dataset = torchvision.datasets.UCF101(root, ann_root, length, fold=fold, train=split,
num_workers=2, _precomputed_metadata=cached_meta_data)
if cached_meta_data is None:
cached_meta_data = dataset.metadata
self.assertGreater(len(dataset), 0)
video, audio, label = dataset[0]
self.assertEqual(video.size(), (length, 320, 240, 3))
self.assertEqual(audio.numel(), 0)
self.assertEqual(label, 0)
video, audio, label = dataset[len(dataset) - 1]
self.assertEqual(video.size(), (length, 320, 240, 3))
self.assertEqual(audio.numel(), 0)
self.assertEqual(label, 1)
def test_places365(self):
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
with places365_root(split=split, small=small) as places365:
root, data = places365
dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)
self.generic_classification_dataset_test(dataset, num_images=len(data["imgs"]))
def test_places365_transforms(self):
expected_image = "image"
expected_target = "target"
def transform(image):
return expected_image
def target_transform(target):
return expected_target
with places365_root() as places365:
root, data = places365
dataset = torchvision.datasets.Places365(
root, transform=transform, target_transform=target_transform, download=True
)
actual_image, actual_target = dataset[0]
self.assertEqual(actual_image, expected_image)
self.assertEqual(actual_target, expected_target)
def test_places365_devkit_download(self):
for split in ("train-standard", "train-challenge", "val"):
with self.subTest(split=split):
with places365_root(split=split) as places365:
root, data = places365
dataset = torchvision.datasets.Places365(root, split=split, download=True)
with self.subTest("classes"):
self.assertSequenceEqual(dataset.classes, data["classes"])
with self.subTest("class_to_idx"):
self.assertDictEqual(dataset.class_to_idx, data["class_to_idx"])
with self.subTest("imgs"):
self.assertSequenceEqual(dataset.imgs, data["imgs"])
def test_places365_devkit_no_download(self):
for split in ("train-standard", "train-challenge", "val"):
with self.subTest(split=split):
with places365_root(split=split) as places365:
root, data = places365
with self.assertRaises(RuntimeError):
torchvision.datasets.Places365(root, split=split, download=False)
def test_places365_images_download(self):
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
with self.subTest(split=split, small=small):
with places365_root(split=split, small=small) as places365:
root, data = places365
dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)
                    self.assertTrue(all(os.path.exists(item[0]) for item in dataset.imgs))
def test_places365_images_download_preexisting(self):
split = "train-standard"
small = False
images_dir = "data_large_standard"
with places365_root(split=split, small=small) as places365:
root, data = places365
os.mkdir(os.path.join(root, images_dir))
with self.assertRaises(RuntimeError):
torchvision.datasets.Places365(root, split=split, small=small, download=True)
def test_places365_repr_smoke(self):
with places365_root() as places365:
root, data = places365
dataset = torchvision.datasets.Places365(root, download=True)
self.assertIsInstance(repr(dataset), str)
class STL10Tester(DatasetTestcase):
@contextlib.contextmanager
def mocked_root(self):
with stl10_root() as (root, data):
yield root, data
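    # Builds an STL10 dataset on top of the fake archive; pre_extract=True unpacks
    # the archive beforehand to simulate data left over from an earlier download.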
@contextlib.contextmanager
def mocked_dataset(self, pre_extract=False, download=True, **kwargs):
with self.mocked_root() as (root, data):
if pre_extract:
utils.extract_archive(os.path.join(root, data["archive"]))
dataset = torchvision.datasets.STL10(root, download=download, **kwargs)
yield dataset, data
def test_not_found(self):
with self.assertRaises(RuntimeError):
with self.mocked_dataset(download=False):
pass
def test_splits(self):
for split in ('train', 'train+unlabeled', 'unlabeled', 'test'):
with self.mocked_dataset(split=split) as (dataset, data):
                num_images = sum(data["num_images_in_split"][part] for part in split.split("+"))
self.generic_classification_dataset_test(dataset, num_images=num_images)
def test_folds(self):
for fold in range(10):
with self.mocked_dataset(split="train", folds=fold) as (dataset, data):
num_images = data["num_images_in_folds"][fold]
self.assertEqual(len(dataset), num_images)
def test_invalid_folds1(self):
with self.assertRaises(ValueError):
with self.mocked_dataset(folds=10):
pass
def test_invalid_folds2(self):
with self.assertRaises(ValueError):
with self.mocked_dataset(folds="0"):
pass
def test_transforms(self):
expected_image = "image"
expected_target = "target"
def transform(image):
return expected_image
def target_transform(target):
return expected_target
with self.mocked_dataset(transform=transform, target_transform=target_transform) as (dataset, _):
actual_image, actual_target = dataset[0]
self.assertEqual(actual_image, expected_image)
self.assertEqual(actual_target, expected_target)
def test_unlabeled(self):
with self.mocked_dataset(split="unlabeled") as (dataset, _):
labels = [dataset[idx][1] for idx in range(len(dataset))]
            self.assertTrue(all(label == -1 for label in labels))
    @mock.patch("torchvision.datasets.stl10.download_and_extract_archive")
    def test_download_preexisting(self, mock_download):
        with self.mocked_dataset(pre_extract=True) as (dataset, data):
            mock_download.assert_not_called()
def test_repr_smoke(self):
with self.mocked_dataset() as (dataset, _):
self.assertIsInstance(repr(dataset), str)
class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Caltech101
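    # The target is an int for "category", an ndarray contour for "annotation",
    # and a tuple when both target types are requested at once.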
FEATURE_TYPES = (PIL.Image.Image, (int, np.ndarray, tuple))
CONFIGS = datasets_utils.combinations_grid(target_type=("category", "annotation", ["category", "annotation"]))
REQUIRED_PACKAGES = ("scipy",)
def inject_fake_data(self, tmpdir, config):
root = pathlib.Path(tmpdir) / "caltech101"
images = root / "101_ObjectCategories"
annotations = root / "Annotations"
categories = (("Faces", "Faces_2"), ("helicopter", "helicopter"), ("ying_yang", "ying_yang"))
num_images_per_category = 2
for image_category, annotation_category in categories:
datasets_utils.create_image_folder(
root=images,
name=image_category,
file_name_fn=lambda idx: f"image_{idx + 1:04d}.jpg",
num_examples=num_images_per_category,
)
self._create_annotation_folder(
root=annotations,
name=annotation_category,
file_name_fn=lambda idx: f"annotation_{idx + 1:04d}.mat",
num_examples=num_images_per_category,
)
# This is included in the original archive, but is removed by the dataset. Thus, an empty directory suffices.
os.makedirs(images / "BACKGROUND_Google")
return num_images_per_category * len(categories)
def _create_annotation_folder(self, root, name, file_name_fn, num_examples):
root = pathlib.Path(root) / name
os.makedirs(root)
for idx in range(num_examples):
self._create_annotation_file(root, file_name_fn(idx))
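    # Each fake .mat file holds a random 2xN "obj_contour" array, matching the
    # field the dataset reads back through scipy.io.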
def _create_annotation_file(self, root, name):
mdict = dict(obj_contour=torch.rand((2, torch.randint(3, 6, size=())), dtype=torch.float64).numpy())
datasets_utils.lazy_importer.scipy.io.savemat(str(pathlib.Path(root) / name), mdict)
def test_combined_targets(self):
target_types = ["category", "annotation"]
individual_targets = []
for target_type in target_types:
with self.create_dataset(target_type=target_type) as (dataset, _):
_, target = dataset[0]
individual_targets.append(target)
with self.create_dataset(target_type=target_types) as (dataset, _):
_, combined_targets = dataset[0]
actual = len(individual_targets)
expected = len(combined_targets)
self.assertEqual(
actual,
expected,
f"The number of the returned combined targets does not match the the number targets if requested "
f"individually: {actual} != {expected}",
)
for target_type, combined_target, individual_target in zip(target_types, combined_targets, individual_targets):
with self.subTest(target_type=target_type):
actual = type(combined_target)
expected = type(individual_target)
self.assertIs(
actual,
expected,
f"Type of the combined target does not match the type of the corresponding individual target: "
f"{actual} is not {expected}",
)
class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Caltech256
def inject_fake_data(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir) / "caltech256" / "256_ObjectCategories"
categories = ((1, "ak47"), (127, "laptop-101"), (257, "clutter"))
num_images_per_category = 2
for idx, category in categories:
datasets_utils.create_image_folder(
tmpdir,
name=f"{idx:03d}.{category}",
file_name_fn=lambda image_idx: f"{idx:03d}_{image_idx + 1:04d}.jpg",
num_examples=num_images_per_category,
)
return num_images_per_category * len(categories)
class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.CIFAR10
CONFIGS = datasets_utils.combinations_grid(train=(True, False))
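    # Everything that differs between the CIFAR-10 and CIFAR-100 archives is kept
    # in this dict, so the CIFAR-100 test case below only overrides the config.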
_VERSION_CONFIG = dict(
base_folder="cifar-10-batches-py",
train_files=tuple(f"data_batch_{idx}" for idx in range(1, 6)),
test_files=("test_batch",),
labels_key="labels",
meta_file="batches.meta",
num_categories=10,
categories_key="label_names",
)
def inject_fake_data(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir) / self._VERSION_CONFIG["base_folder"]
os.makedirs(tmpdir)
num_images_per_file = 1
for name in itertools.chain(self._VERSION_CONFIG["train_files"], self._VERSION_CONFIG["test_files"]):
self._create_batch_file(tmpdir, name, num_images_per_file)
categories = self._create_meta_file(tmpdir)
return dict(
num_examples=num_images_per_file
* len(self._VERSION_CONFIG["train_files"] if config["train"] else self._VERSION_CONFIG["test_files"]),
categories=categories,
)
def _create_batch_file(self, root, name, num_images):
data = datasets_utils.create_image_or_video_tensor((num_images, 32 * 32 * 3))
labels = np.random.randint(0, self._VERSION_CONFIG["num_categories"], size=num_images).tolist()
self._create_binary_file(root, name, {"data": data, self._VERSION_CONFIG["labels_key"]: labels})
def _create_meta_file(self, root):
categories = [
f"{idx:0{len(str(self._VERSION_CONFIG['num_categories'] - 1))}d}"
for idx in range(self._VERSION_CONFIG["num_categories"])
]
self._create_binary_file(
root, self._VERSION_CONFIG["meta_file"], {self._VERSION_CONFIG["categories_key"]: categories}
)
return categories
def _create_binary_file(self, root, name, content):
with open(pathlib.Path(root) / name, "wb") as fh:
pickle.dump(content, fh)
def test_class_to_idx(self):
with self.create_dataset() as (dataset, info):
expected = {category: label for label, category in enumerate(info["categories"])}
actual = dataset.class_to_idx
self.assertEqual(actual, expected)
class CIFAR100TestCase(CIFAR10TestCase):
DATASET_CLASS = datasets.CIFAR100
_VERSION_CONFIG = dict(
base_folder="cifar-100-python",
train_files=("train",),
test_files=("test",),
labels_key="fine_labels",
meta_file="meta",
num_categories=100,
categories_key="fine_label_names",
)
class CelebATestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.CelebA
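    # Depending on target_type, the target is a tensor (attr/bbox/landmarks), an
    # identity label, a tuple of several targets, or None when none is requested.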
FEATURE_TYPES = (PIL.Image.Image, (torch.Tensor, int, tuple, type(None)))
CONFIGS = datasets_utils.combinations_grid(
split=("train", "valid", "test", "all"),
target_type=("attr", "identity", "bbox", "landmarks", ["attr", "identity"]),
)
REQUIRED_PACKAGES = ("pandas",)
_SPLIT_TO_IDX = dict(train=0, valid=1, test=2)
def inject_fake_data(self, tmpdir, config):
base_folder = pathlib.Path(tmpdir) / "celeba"
os.makedirs(base_folder)
num_images, num_images_per_split = self._create_split_txt(base_folder)
datasets_utils.create_image_folder(
base_folder, "img_align_celeba", lambda idx: f"{idx + 1:06d}.jpg", num_images
)
attr_names = self._create_attr_txt(base_folder, num_images)
self._create_identity_txt(base_folder, num_images)
self._create_bbox_txt(base_folder, num_images)
self._create_landmarks_txt(base_folder, num_images)
return dict(num_examples=num_images_per_split[config["split"]], attr_names=attr_names)
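    # list_eval_partition.txt maps each image to a split index (0/1/2); the "all"
    # split is simply the concatenation of train, valid and test.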
def _create_split_txt(self, root):
num_images_per_split = dict(train=3, valid=2, test=1)
data = [
[self._SPLIT_TO_IDX[split]] for split, num_images in num_images_per_split.items() for _ in range(num_images)
]
self._create_txt(root, "list_eval_partition.txt", data)
num_images_per_split["all"] = num_images = sum(num_images_per_split.values())
return num_images, num_images_per_split
def _create_attr_txt(self, root, num_images):
header = ("5_o_Clock_Shadow", "Young")
data = torch.rand((num_images, len(header))).ge(0.5).int().mul(2).sub(1).tolist()
self._create_txt(root, "list_attr_celeba.txt", data, header=header, add_num_examples=True)
return header
def _create_identity_txt(self, root, num_images):
data = torch.randint(1, 4, size=(num_images, 1)).tolist()
self._create_txt(root, "identity_CelebA.txt", data)
def _create_bbox_txt(self, root, num_images):
header = ("x_1", "y_1", "width", "height")
data = torch.randint(10, size=(num_images, len(header))).tolist()
self._create_txt(
root, "list_bbox_celeba.txt", data, header=header, add_num_examples=True, add_image_id_to_header=True
)
def _create_landmarks_txt(self, root, num_images):
header = ("lefteye_x", "rightmouth_y")
data = torch.randint(10, size=(num_images, len(header))).tolist()
self._create_txt(root, "list_landmarks_align_celeba.txt", data, header=header, add_num_examples=True)
def _create_txt(self, root, name, data, header=None, add_num_examples=False, add_image_id_to_header=False):
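        # CelebA txt files optionally begin with the number of examples and a header
        # row; each data row is keyed by the image file name.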
with open(pathlib.Path(root) / name, "w") as fh:
if add_num_examples:
fh.write(f"{len(data)}\n")
if header:
if add_image_id_to_header:
header = ("image_id", *header)
fh.write(f"{' '.join(header)}\n")
for idx, line in enumerate(data, 1):
fh.write(f"{' '.join((f'{idx:06d}.jpg', *[str(value) for value in line]))}\n")
def test_combined_targets(self):
target_types = ["attr", "identity", "bbox", "landmarks"]
individual_targets = []
for target_type in target_types:
with self.create_dataset(target_type=target_type) as (dataset, _):
_, target = dataset[0]
individual_targets.append(target)
with self.create_dataset(target_type=target_types) as (dataset, _):
_, combined_targets = dataset[0]
actual = len(individual_targets)
expected = len(combined_targets)
self.assertEqual(
actual,
expected,
f"The number of the returned combined targets does not match the the number targets if requested "
f"individually: {actual} != {expected}",
)
for target_type, combined_target, individual_target in zip(target_types, combined_targets, individual_targets):
with self.subTest(target_type=target_type):
actual = type(combined_target)
expected = type(individual_target)
self.assertIs(
actual,
expected,
f"Type of the combined target does not match the type of the corresponding individual target: "
f"{actual} is not {expected}",
)
def test_no_target(self):
with self.create_dataset(target_type=[]) as (dataset, _):
_, target = dataset[0]
self.assertIsNone(target)
def test_attr_names(self):
with self.create_dataset() as (dataset, info):
self.assertEqual(tuple(dataset.attr_names), info["attr_names"])
if __name__ == "__main__":
unittest.main()