Commit e4e167a3, authored by Dougal J. Sutherland and committed by Francisco Massa
Browse files

CelebA: track attr names, support split="all", code cleanup (#1008)

* CelebA: track attr names, support split="all", code cleanup

* fix typo
parent b5db97b4
from functools import partial
import torch import torch
import os import os
import PIL import PIL
...@@ -10,7 +11,7 @@ class CelebA(VisionDataset): ...@@ -10,7 +11,7 @@ class CelebA(VisionDataset):
Args: Args:
root (string): Root directory where images are downloaded to. root (string): Root directory where images are downloaded to.
split (string): One of {'train', 'valid', 'test'}. split (string): One of {'train', 'valid', 'test', 'all'}.
Accordingly dataset is selected. Accordingly dataset is selected.
target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``, target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
or ``landmarks``. Can also be a list to output a tuple with all specified target types. or ``landmarks``. Can also be a list to output a tuple with all specified target types.
...@@ -78,32 +79,28 @@ class CelebA(VisionDataset): ...@@ -78,32 +79,28 @@ class CelebA(VisionDataset):
split = 1 split = 1
elif split.lower() == "test": elif split.lower() == "test":
split = 2 split = 2
elif split.lower() == "all":
split = None
else: else:
raise ValueError('Wrong split entered! Please use split="train" ' raise ValueError('Wrong split entered! Please use "train", '
'or split="valid" or split="test"') '"valid", "test", or "all"')
with open(os.path.join(self.root, self.base_folder, "list_eval_partition.txt"), "r") as f: fn = partial(os.path.join, self.root, self.base_folder)
splits = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0) splits = pandas.read_csv(fn("list_eval_partition.txt"), delim_whitespace=True, header=None, index_col=0)
identity = pandas.read_csv(fn("identity_CelebA.txt"), delim_whitespace=True, header=None, index_col=0)
bbox = pandas.read_csv(fn("list_bbox_celeba.txt"), delim_whitespace=True, header=1, index_col=0)
landmarks_align = pandas.read_csv(fn("list_landmarks_align_celeba.txt"), delim_whitespace=True, header=1)
attr = pandas.read_csv(fn("list_attr_celeba.txt"), delim_whitespace=True, header=1)
with open(os.path.join(self.root, self.base_folder, "identity_CelebA.txt"), "r") as f: mask = slice(None) if split is None else (splits[1] == split)
self.identity = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0)
with open(os.path.join(self.root, self.base_folder, "list_bbox_celeba.txt"), "r") as f:
self.bbox = pandas.read_csv(f, delim_whitespace=True, header=1, index_col=0)
with open(os.path.join(self.root, self.base_folder, "list_landmarks_align_celeba.txt"), "r") as f:
self.landmarks_align = pandas.read_csv(f, delim_whitespace=True, header=1)
with open(os.path.join(self.root, self.base_folder, "list_attr_celeba.txt"), "r") as f:
self.attr = pandas.read_csv(f, delim_whitespace=True, header=1)
mask = (splits[1] == split)
self.filename = splits[mask].index.values self.filename = splits[mask].index.values
self.identity = torch.as_tensor(self.identity[mask].values) self.identity = torch.as_tensor(identity[mask].values)
self.bbox = torch.as_tensor(self.bbox[mask].values) self.bbox = torch.as_tensor(bbox[mask].values)
self.landmarks_align = torch.as_tensor(self.landmarks_align[mask].values) self.landmarks_align = torch.as_tensor(landmarks_align[mask].values)
self.attr = torch.as_tensor(self.attr[mask].values) self.attr = torch.as_tensor(attr[mask].values)
self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1} self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1}
self.attr_names = list(attr.columns)
def _check_integrity(self): def _check_integrity(self):
for (_, md5, filename) in self.file_list: for (_, md5, filename) in self.file_list:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment