celeba.py 8.1 KB
Newer Older
1
import csv
2
import os
3
from collections import namedtuple
Philip Meier's avatar
Philip Meier committed
4
from typing import Any, Callable, List, Optional, Union, Tuple
5
6
7
8

import PIL
import torch

Nicolas Hug's avatar
Nicolas Hug committed
9
from .utils import download_file_from_google_drive, check_integrity, verify_str_arg, extract_archive
10
from .vision import VisionDataset
11

12
13
# Result of parsing one annotation file: ``header`` is the column-name row (an empty list
# when the file has none), ``index`` is the first field of every data row and ``data`` is a
# tensor built from the remaining fields of every data row.
CSV = namedtuple("CSV", "header index data")

14

15
class CelebA(VisionDataset):
16
17
18
19
    """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
20
        split (string): One of {'train', 'valid', 'test', 'all'}.
            The corresponding subset of the dataset is selected.
        target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
            or ``landmarks``. Can also be a list to output a tuple with all specified target types.
            The targets represent:
25
26
27
28
29
30
31

                - ``attr`` (Tensor shape=(40,) dtype=int): binary (0, 1) labels for attributes
                - ``identity`` (int): label for each person (data points with the same identity are the same person)
                - ``bbox`` (Tensor shape=(4,) dtype=int): bounding box (x, y, width, height)
                - ``landmarks`` (Tensor shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
                  righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)

32
            Defaults to ``attr``. If empty, ``None`` will be returned as target.
33

34
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.PILToTensor``
36
37
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
Nicolas Hug's avatar
Nicolas Hug committed
38
39
40
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
41
42
43
44
45
46
47
    """

    base_folder = "celeba"
    # There currently does not appear to be an easy way to extract 7z in python (without introducing additional
    # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
    # right now.
    file_list = [
Aditya Oke's avatar
Aditya Oke committed
48
        # File ID                                      MD5 Hash                            Filename
49
        ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
Aditya Oke's avatar
Aditya Oke committed
50
        # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc","b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
51
52
53
54
55
56
57
58
59
        # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
        ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
        ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
        ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
        ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
        # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
        ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
    ]

Philip Meier's avatar
Philip Meier committed
60
    def __init__(
        self,
        root: str,
        split: str = "train",
        target_type: Union[List[str], str] = "attr",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        """Validate the arguments, optionally download the data and load the annotations.

        Args:
            root: Root directory; the dataset files are looked up in ``<root>/celeba``.
            split: One of ``"train"``, ``"valid"``, ``"test"`` or ``"all"`` (case-insensitive).
            target_type: A single target type, or a list/tuple of target types.
            transform: Optional callable applied to the loaded image.
            target_transform: Optional callable applied to the target.
            download: If true, ``self.download()`` is called before the integrity check.

        Raises:
            RuntimeError: If ``target_transform`` is given while ``target_type`` is empty, or
                if the dataset files are missing or fail the integrity check.
        """
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.split = split

        if isinstance(target_type, (list, tuple)):
            # Copy so that later mutation of the caller's sequence cannot silently change
            # the dataset; a tuple of target types is accepted the same way as a list.
            self.target_type = list(target_type)
        else:
            self.target_type = [target_type]

        if not self.target_type and self.target_transform is not None:
            raise RuntimeError("target_transform is specified but target_type is empty")

        # Validate ``split`` before touching the network or the disk so that a typo fails
        # immediately instead of after a lengthy download.
        split_map = {
            "train": 0,
            "valid": 1,
            "test": 2,
            "all": None,
        }
        split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        splits = self._load_csv("list_eval_partition.txt")
        identity = self._load_csv("identity_CelebA.txt")
        bbox = self._load_csv("list_bbox_celeba.txt", header=1)
        landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1)
        attr = self._load_csv("list_attr_celeba.txt", header=1)

        mask: Union[slice, torch.Tensor]
        if split_ is None:  # split == "all": keep every row
            mask = slice(None)
            self.filename = splits.index
        else:
            # ``reshape(-1)`` always yields a 1-D boolean mask, whereas a bare ``squeeze()``
            # collapses to a 0-d tensor when the partition file has exactly one row.
            mask = (splits.data == split_).reshape(-1)
            self.filename = [name for name, keep in zip(splits.index, mask.tolist()) if keep]

        self.identity = identity.data[mask]
        self.bbox = bbox.data[mask]
        self.landmarks_align = landmarks_align.data[mask]
        self.attr = attr.data[mask]
        # map from {-1, 1} to {0, 1}
        self.attr = torch.div(self.attr + 1, 2, rounding_mode="floor")
        self.attr_names = attr.header

    def _load_csv(
        self,
        filename: str,
        header: Optional[int] = None,
    ) -> CSV:
        """Parse a space-delimited annotation file from the dataset folder.

        Args:
            filename: Name of the file inside ``<root>/<base_folder>``.
            header: Zero-based index of the row holding the column names. That row and
                every row before it are excluded from the data. ``None`` means the file
                has no header row.

        Returns:
            A ``CSV`` tuple: the header row (empty list when ``header`` is ``None``), the
            first field of every data row, and a tensor of the remaining fields parsed
            as integers.
        """
        # The csv module asks for files opened with newline="" so that it can handle
        # line endings itself.
        with open(os.path.join(self.root, self.base_folder, filename), newline="") as csv_file:
            rows = list(csv.reader(csv_file, delimiter=" ", skipinitialspace=True))

        if header is not None:
            headers = rows[header]
            rows = rows[header + 1 :]
        else:
            headers = []

        # Blank or whitespace-only lines yield rows without any non-empty field; skip them
        # instead of failing on ``row[0]`` / ``int("")`` below.
        rows = [row for row in rows if any(row)]

        indices = [row[0] for row in rows]
        data_int = [[int(value) for value in row[1:]] for row in rows]

        return CSV(headers, indices, torch.tensor(data_int))
131

Philip Meier's avatar
Philip Meier committed
132
    def _check_integrity(self) -> bool:
133
134
135
136
137
138
139
140
141
142
143
        for (_, md5, filename) in self.file_list:
            fpath = os.path.join(self.root, self.base_folder, filename)
            _, ext = os.path.splitext(filename)
            # Allow original archive to be deleted (zip and 7z)
            # Only need the extracted images
            if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
                return False

        # Should check a hash of the images
        return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))

Philip Meier's avatar
Philip Meier committed
144
    def download(self) -> None:
145
        if self._check_integrity():
146
            print("Files already downloaded and verified")
147
148
            return

Nicolas Hug's avatar
Nicolas Hug committed
149
150
151
152
        for (file_id, md5, filename) in self.file_list:
            download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)

        extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"))
153

Philip Meier's avatar
Philip Meier committed
154
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load the image at ``index`` together with the requested targets.

        Returns:
            ``(image, target)``. ``target`` is ``None`` when no target type is configured,
            the single value when exactly one is configured, and a tuple otherwise. The
            transforms, when set, are applied to the image and to the target respectively.

        Raises:
            ValueError: If an entry of ``self.target_type`` is not a known target type.
        """
        image_path = os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])
        X = PIL.Image.open(image_path)

        collected: Any = []
        for kind in self.target_type:
            if kind == "attr":
                value = self.attr[index, :]
            elif kind == "identity":
                value = self.identity[index, 0]
            elif kind == "bbox":
                value = self.bbox[index, :]
            elif kind == "landmarks":
                value = self.landmarks_align[index, :]
            else:
                # TODO: refactor with utils.verify_str_arg
                raise ValueError(f'Target type "{kind}" is not recognized.')
            collected.append(value)

        if self.transform is not None:
            X = self.transform(X)

        # No target requested: the target transform is never applied.
        if not collected:
            return X, None

        target = collected[0] if len(collected) == 1 else tuple(collected)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return X, target

Philip Meier's avatar
Philip Meier committed
184
    def __len__(self) -> int:
185
186
        return len(self.attr)

Philip Meier's avatar
Philip Meier committed
187
    def extra_repr(self) -> str:
188
        lines = ["Target type: {target_type}", "Split: {split}"]
189
        return "\n".join(lines).format(**self.__dict__)