from functools import partial
import torch
import os
import PIL
from .vision import VisionDataset
from .utils import download_file_from_google_drive, check_integrity


class CelebA(VisionDataset):
    """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        split (string): One of {'train', 'valid', 'test', 'all'}. Selects the
            corresponding subset of the dataset.
        target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
            or ``landmarks``. Can also be a list to output a tuple with all specified target types.
            The targets represent:
                ``attr`` (Tensor shape=(40,) dtype=int): binary (0, 1) labels for attributes
                ``identity`` (int): label for each person (data points with the same identity are the same person)
                ``bbox`` (Tensor shape=(4,) dtype=int): bounding box (x, y, width, height)
                ``landmarks`` (Tensor shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
                    righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
            Defaults to ``attr``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is not
            downloaded again.
    """

    base_folder = "celeba"
    # There currently does not appear to be an easy way to extract 7z in python (without introducing additional
    # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
    # right now.
    file_list = [
        # File ID                               MD5 Hash                            Filename
        ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
        # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
        # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
        ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
        ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
        ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
        ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
        # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
        ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
    ]

    def __init__(self, root,
                 split="train",
                 target_type="attr",
                 transform=None, target_transform=None,
                 download=False):
        import pandas
        super(CelebA, self).__init__(root)
        self.split = split
        if isinstance(target_type, list):
            self.target_type = target_type
        else:
            self.target_type = [target_type]
        self.transform = transform
        self.target_transform = target_transform

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if split.lower() == "train":
            split = 0
        elif split.lower() == "valid":
            split = 1
        elif split.lower() == "test":
            split = 2
        elif split.lower() == "all":
            split = None
        else:
            raise ValueError('Wrong split entered! Please use "train", '
                             '"valid", "test", or "all"')
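        # Each annotation file below is whitespace separated with the image filename as
        # its first column, so pandas indexes every table by filename.  The boolean mask
        # built from list_eval_partition.txt therefore selects the requested split
        # consistently across all tables ("all" keeps every row via slice(None)).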
        fn = partial(os.path.join, self.root, self.base_folder)
        splits = pandas.read_csv(fn("list_eval_partition.txt"), delim_whitespace=True, header=None, index_col=0)
        identity = pandas.read_csv(fn("identity_CelebA.txt"), delim_whitespace=True, header=None, index_col=0)
        bbox = pandas.read_csv(fn("list_bbox_celeba.txt"), delim_whitespace=True, header=1, index_col=0)
        landmarks_align = pandas.read_csv(fn("list_landmarks_align_celeba.txt"), delim_whitespace=True, header=1)
        attr = pandas.read_csv(fn("list_attr_celeba.txt"), delim_whitespace=True, header=1)

        mask = slice(None) if split is None else (splits[1] == split)

        self.filename = splits[mask].index.values
        self.identity = torch.as_tensor(identity[mask].values)
        self.bbox = torch.as_tensor(bbox[mask].values)
        self.landmarks_align = torch.as_tensor(landmarks_align[mask].values)
        self.attr = torch.as_tensor(attr[mask].values)
        self.attr = (self.attr + 1) // 2  # map from {-1, 1} to {0, 1}
        self.attr_names = list(attr.columns)

    def _check_integrity(self):
        for (_, md5, filename) in self.file_list:
            fpath = os.path.join(self.root, self.base_folder, filename)
            _, ext = os.path.splitext(filename)
            # Allow original archive to be deleted (zip and 7z)
            # Only need the extracted images
            if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
                return False

        # Should check a hash of the images
        return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))

    def download(self):
        import zipfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        for (file_id, md5, filename) in self.file_list:
            download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)

        with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
            f.extractall(os.path.join(self.root, self.base_folder))

    def __getitem__(self, index):
        X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))

        target = []
        for t in self.target_type:
            if t == "attr":
                target.append(self.attr[index, :])
            elif t == "identity":
                target.append(self.identity[index, 0])
            elif t == "bbox":
                target.append(self.bbox[index, :])
            elif t == "landmarks":
                target.append(self.landmarks_align[index, :])
            else:
                raise ValueError("Target type \"{}\" is not recognized.".format(t))
        target = tuple(target) if len(target) > 1 else target[0]

        if self.transform is not None:
            X = self.transform(X)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return X, target

    def __len__(self):
        return len(self.attr)

    def extra_repr(self):
        lines = ["Target type: {target_type}", "Split: {split}"]
        return '\n'.join(lines).format(**self.__dict__)
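

# Usage sketch (illustrative only, not part of the dataset implementation): shows how the
# class above is typically combined with a transform and a DataLoader.  It assumes this
# module is executed inside the torchvision package (e.g. ``python -m
# torchvision.datasets.celeba``), that "./data" is writable, and that the Google Drive
# quota permits the download; adjust root, split, and target_type to your setup.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from torchvision import transforms

    celeba = CelebA(root="./data", split="valid",
                    target_type=["attr", "identity"],
                    transform=transforms.ToTensor(),
                    download=True)

    # With a list target_type, each sample yields (image, (attr, identity)).
    image, (attr, identity) = celeba[0]
    print(len(celeba), image.shape, attr.shape, int(identity))

    loader = DataLoader(celeba, batch_size=32, shuffle=True)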