import csv
import os
from collections import namedtuple
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union

import PIL
import torch

from .utils import check_integrity, download_file_from_google_drive, extract_archive, verify_str_arg
from .vision import VisionDataset

# Parsed form of one of CelebA's space-delimited annotation files (built by ``CelebA._load_csv``):
#   header -- the column names (empty list when the file has no header row)
#   index  -- the first column of every data row
#   data   -- integer tensor holding the remaining columns, one row per data row
CSV = namedtuple("CSV", ["header", "index", "data"])


class CelebA(VisionDataset):
    """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        split (string): One of {'train', 'valid', 'test', 'all'}.
            Accordingly dataset is selected.
        target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
            or ``landmarks``. Can also be a list to output a tuple with all specified target types.
            The targets represent:

                - ``attr`` (Tensor shape=(40,) dtype=int): binary (0, 1) labels for attributes
                - ``identity`` (int, returned as a 0-d Tensor): label for each person (data points with the
                  same identity are the same person)
                - ``bbox`` (Tensor shape=(4,) dtype=int): bounding box (x, y, width, height)
                - ``landmarks`` (Tensor shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
                  righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)

            Defaults to ``attr``. If empty, ``None`` will be returned as target.

        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

            .. warning::

                To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.

    Raises:
        RuntimeError: If ``target_transform`` is given while ``target_type`` is empty, or if the
            dataset files are missing or fail the integrity check.
        ValueError: If an entry of ``target_type`` is not recognized. This is only detected when an
            item is fetched. ``split`` is validated at construction via ``verify_str_arg``.
    """

    base_folder = "celeba"
    # There currently does not appear to be an easy way to extract 7z in python (without introducing additional
    # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
    # right now.
    file_list = [
        # File ID                                      MD5 Hash                            Filename
        ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
        # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc","b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
        # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
        ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
        ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
        ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
        ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
        # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
        ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
    ]

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        target_type: Union[List[str], str] = "attr",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.split = split
        # Normalize to a list so __getitem__ can always iterate over the requested target types.
        if isinstance(target_type, list):
            self.target_type = target_type
        else:
            self.target_type = [target_type]

        if not self.target_type and self.target_transform is not None:
            raise RuntimeError("target_transform is specified but target_type is empty")

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        # Partition codes as stored in list_eval_partition.txt; ``None`` selects every image.
        split_map = {
            "train": 0,
            "valid": 1,
            "test": 2,
            "all": None,
        }
        split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))]

        splits = self._load_csv("list_eval_partition.txt")
        identity = self._load_csv("identity_CelebA.txt")
        bbox = self._load_csv("list_bbox_celeba.txt", header=1)
        landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1)
        attr = self._load_csv("list_attr_celeba.txt", header=1)

        mask: Union[slice, torch.Tensor]
        if split_ is None:  # split == "all"
            mask = slice(None)
            self.filename = splits.index
        else:
            # splits.data holds one partition code per row (expected shape (N, 1)). Flatten with
            # reshape(-1) instead of squeeze(): squeeze() would turn a single-match result into a
            # 0-d tensor, which cannot be iterated over.
            mask = (splits.data == split_).reshape(-1)
            self.filename = [name for name, keep in zip(splits.index, mask.tolist()) if keep]

        self.identity = identity.data[mask]
        self.bbox = bbox.data[mask]
        self.landmarks_align = landmarks_align.data[mask]
        self.attr = attr.data[mask]
        # map from {-1, 1} to {0, 1}
        self.attr = torch.div(self.attr + 1, 2, rounding_mode="floor")
        self.attr_names = attr.header

    def _load_csv(
        self,
        filename: str,
        header: Optional[int] = None,
    ) -> CSV:
        """Parse a space-delimited annotation file located in ``<root>/celeba``.

        Args:
            filename: Name of the file inside the dataset folder.
            header: Index of the row holding the column names. That row and every row before it
                are excluded from the data. ``None`` means the file has no header row.

        Returns:
            ``CSV(header, index, data)``, where ``index`` is the first column of every data row and
            ``data`` is an integer tensor built from the remaining columns.
        """
        with open(os.path.join(self.root, self.base_folder, filename)) as csv_file:
            # skipinitialspace makes runs of consecutive spaces act as a single delimiter.
            data = list(csv.reader(csv_file, delimiter=" ", skipinitialspace=True))

        if header is not None:
            headers = data[header]
            data = data[header + 1 :]
        else:
            headers = []

        indices = [row[0] for row in data]
        data = [row[1:] for row in data]
        data_int = [list(map(int, i)) for i in data]

        return CSV(headers, indices, torch.tensor(data_int))

    def _check_integrity(self) -> bool:
        """Return True when every non-archive file in ``file_list`` matches its MD5 and the
        extracted ``img_align_celeba`` folder exists."""
        for (_, md5, filename) in self.file_list:
            fpath = os.path.join(self.root, self.base_folder, filename)
            _, ext = os.path.splitext(filename)
            # Allow original archive to be deleted (zip and 7z)
            # Only need the extracted images
            if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
                return False

        # Should check a hash of the images
        return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))

    def download(self) -> None:
        """Fetch every file in ``file_list`` from Google Drive and extract the aligned-image zip.

        Does nothing (besides printing a message) if the dataset already passes the integrity check.
        """
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        for (file_id, md5, filename) in self.file_list:
            download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)

        extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(image, target)`` for sample ``index``.

        ``target`` is a single value when one target type is requested, a tuple when several are
        requested, and ``None`` when ``target_type`` is empty.
        """
        X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))

        target: Any = []
        for t in self.target_type:
            if t == "attr":
                target.append(self.attr[index, :])
            elif t == "identity":
                target.append(self.identity[index, 0])
            elif t == "bbox":
                target.append(self.bbox[index, :])
            elif t == "landmarks":
                target.append(self.landmarks_align[index, :])
            else:
                # TODO: refactor with utils.verify_str_arg
                raise ValueError(f'Target type "{t}" is not recognized.')

        if self.transform is not None:
            X = self.transform(X)

        if target:
            target = tuple(target) if len(target) > 1 else target[0]

            if self.target_transform is not None:
                target = self.target_transform(target)
        else:
            target = None

        return X, target

    def __len__(self) -> int:
        """Number of images in the selected split."""
        return len(self.attr)

    def extra_repr(self) -> str:
        """Describe the configured target types and split for ``repr``."""
        lines = ["Target type: {target_type}", "Split: {split}"]
        return "\n".join(lines).format(**self.__dict__)