Unverified commit cadb1681, authored by Vassilis C. Nicodemou, committed by GitHub
Browse files

Fix splitting CelebA dataset (#4377)


Co-authored-by: Vassilis Nicodemou <nikodim@ics.forth.gr>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent de3e9091
......@@ -512,7 +512,7 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
return dict(num_examples=num_images_per_split[config["split"]], attr_names=attr_names)
def _create_split_txt(self, root):
num_images_per_split = dict(train=3, valid=2, test=1)
num_images_per_split = dict(train=4, valid=3, test=2)
data = [
[self._SPLIT_TO_IDX[split]] for split, num_images in num_images_per_split.items() for _ in range(num_images)
......@@ -595,6 +595,17 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
with self.create_dataset() as (dataset, info):
assert tuple(dataset.attr_names) == info["attr_names"]
def test_images_names_split(self):
    """Check that the train/valid/test filenames partition the "all" filenames.

    Two properties are verified:

    * the union of the filenames of the three splits equals the filenames
      of the ``"all"`` split, and
    * the splits do not overlap, i.e. their sizes add up to the size of
      the ``"all"`` split.

    The second check is what catches a dataset that ignores ``split`` when
    populating ``filename``: if every split exposed the complete file
    list, the union alone would still equal the ``"all"`` set.
    """
    with self.create_dataset(split="all") as (dataset, _):
        all_imgs_names = set(dataset.filename)

    merged_imgs_names = set()
    num_imgs_in_splits = 0
    for split in ["train", "valid", "test"]:
        with self.create_dataset(split=split) as (dataset, _):
            split_imgs_names = set(dataset.filename)
        num_imgs_in_splits += len(split_imgs_names)
        merged_imgs_names.update(split_imgs_names)

    assert merged_imgs_names == all_imgs_names
    # Equal union + equal total count implies the splits are pairwise disjoint.
    assert num_imgs_in_splits == len(all_imgs_names)
class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.VOCSegmentation
......
......@@ -99,7 +99,10 @@ class CelebA(VisionDataset):
mask = slice(None) if split_ is None else (splits.data == split_).squeeze()
self.filename = splits.index
if mask == slice(None): # if split == "all"
self.filename = splits.index
else:
self.filename = [splits.index[i] for i in torch.squeeze(torch.nonzero(mask))]
self.identity = identity.data[mask]
self.bbox = bbox.data[mask]
self.landmarks_align = landmarks_align.data[mask]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment