"docs/vscode:/vscode.git/clone" did not exist on "6e316588f87f5a428b0fc46adb505b28a189a96d"
celeba.py 8.97 KB
Newer Older
1
import csv
import os
import warnings
from collections import namedtuple
from typing import Any, Callable, List, Optional, Tuple, Union

import PIL
import PIL.Image
import torch

from .utils import check_integrity, verify_str_arg
from .vision import VisionDataset
12

13
14
# Parsed annotation file: ``header`` (column-name row or []), ``index`` (first column of each
# data row) and ``data`` (remaining columns).
CSV = namedtuple("CSV", ("header", "index", "data"))

15

16
class CelebA(VisionDataset):
    """`Large-scale CelebFaces Attributes (CelebA) <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.

    Args:
        root (string): Root directory of the dataset. The files are expected in ``root/celeba``.
        split (string): One of {'train', 'valid', 'test', 'all'} (case-insensitive).
            Accordingly dataset is selected.
        target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
            or ``landmarks``. Can also be a list to output a tuple with all specified target types.
            The targets represent:

                - ``attr`` (Tensor shape=(40,) dtype=int64): binary (0, 1) labels for attributes
                - ``identity`` (Tensor shape=() dtype=int64): label for each person (data points with the
                  same identity are the same person)
                - ``bbox`` (Tensor shape=(4,) dtype=int64): bounding box (x, y, width, height)
                - ``landmarks`` (Tensor shape=(10,) dtype=int64): landmark points (lefteye_x, lefteye_y,
                  righteye_x, righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)

            Defaults to ``attr``. If empty, ``None`` will be returned as target.

        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): Deprecated.

            .. warning::

                Downloading CelebA is not supported anymore as of 0.13 and this
                parameter will be removed in 0.15. See
                `this issue <https://github.com/pytorch/vision/issues/5705>`__
                for more details.
                Please download the files from
                https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract
                them in ``root/celeba``.

    Raises:
        RuntimeError: If ``target_transform`` is given while ``target_type`` is empty, or if an
            annotation file is missing / fails its MD5 check, or the ``img_align_celeba`` folder
            is missing.
    """

    base_folder = "celeba"
    # There currently does not appear to be a easy way to extract 7z in python (without introducing additional
    # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
    # right now.
    file_list = [
        # File ID                                      MD5 Hash                            Filename
        ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
        # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc","b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
        # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
        ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
        ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
        ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
        ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
        # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
        ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
    ]

    def __init__(
        self,
        root: str,
        split: str = "train",
        target_type: Union[List[str], str] = "attr",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: Optional[bool] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)

        self.split = split
        # Normalize to a list so ``__getitem__`` can always iterate over the requested targets.
        if isinstance(target_type, list):
            self.target_type = target_type
        else:
            self.target_type = [target_type]

        if not self.target_type and self.target_transform is not None:
            raise RuntimeError("target_transform is specified but target_type is empty")

        if download is not None:
            warnings.warn(
                "Downloading CelebA is not supported anymore as of 0.13, and the "
                "download parameter will be removed in 0.15. See "
                "https://github.com/pytorch/vision/issues/5705 for more details. "
                "Please download the files from "
                "https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract them "
                "in ``root/celeba``."
            )

        if download:
            # ``download()`` only returns when the files are already present; otherwise it raises.
            self.download()

        if not self._check_integrity():
            # Downloading is unsupported, so point users at the manual procedure instead of
            # suggesting ``download=True``.
            raise RuntimeError(
                "Dataset not found or corrupted. Please download the files from "
                "https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract them "
                "in ``root/celeba``."
            )

        # Partition ids used in list_eval_partition.txt; ``None`` selects every row.
        split_map = {
            "train": 0,
            "valid": 1,
            "test": 2,
            "all": None,
        }
        split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))]

        splits = self._load_csv("list_eval_partition.txt")
        identity = self._load_csv("identity_CelebA.txt")
        bbox = self._load_csv("list_bbox_celeba.txt", header=1)
        landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1)
        attr = self._load_csv("list_attr_celeba.txt", header=1)

        mask: Union[slice, torch.Tensor]
        if split_ is None:  # split == "all"
            mask = slice(None)
            self.filename = splits.index
        else:
            # One partition id per row, so flatten to a 1-D boolean mask. ``reshape(-1)`` (unlike a
            # bare ``squeeze()``) keeps the mask 1-D even when only a single row is present.
            mask = (splits.data == split_).reshape(-1)
            # Zip over plain bools: iterating ``squeeze(nonzero(mask))`` breaks when exactly one
            # row matches (it becomes a 0-d tensor, which is not iterable).
            self.filename = [name for name, keep in zip(splits.index, mask.tolist()) if keep]

        self.identity = identity.data[mask]
        self.bbox = bbox.data[mask]
        self.landmarks_align = landmarks_align.data[mask]
        # map from {-1, 1} to {0, 1}
        self.attr = torch.div(attr.data[mask] + 1, 2, rounding_mode="floor")
        self.attr_names = attr.header

    def _load_csv(
        self,
        filename: str,
        header: Optional[int] = None,
    ) -> CSV:
        """Parse a space-separated annotation file located in ``root/celeba``.

        Args:
            filename: Name of the file inside ``root/celeba``.
            header: Index of the row holding the column names. That row and every row before it
                are dropped from the data. ``None`` means the file has no header row.

        Returns:
            ``CSV(header, index, data)`` where ``header`` is the column-name row (``[]`` when
            ``header`` is ``None``), ``index`` is the first column of each data row and ``data`` is
            an integer tensor built from the remaining columns.
        """
        # The csv module expects files opened with newline="" so it handles line endings itself.
        with open(os.path.join(self.root, self.base_folder, filename), newline="") as csv_file:
            rows = list(csv.reader(csv_file, delimiter=" ", skipinitialspace=True))

        if header is not None:
            headers = rows[header]
            rows = rows[header + 1 :]
        else:
            headers = []

        indices = [row[0] for row in rows]
        values = [[int(value) for value in row[1:]] for row in rows]

        return CSV(headers, indices, torch.tensor(values))

    def _check_integrity(self) -> bool:
        """Return True if all annotation files pass their MD5 check and the image folder exists.

        Archives (``.zip`` / ``.7z``) are skipped, so they may be deleted after extraction.
        """
        for _, md5, filename in self.file_list:
            fpath = os.path.join(self.root, self.base_folder, filename)
            _, ext = os.path.splitext(filename)
            # Allow original archive to be deleted (zip and 7z)
            # Only need the extracted images
            if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
                return False

        # Should check a hash of the images
        return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))

    def download(self) -> None:
        """Deprecated: only reports when the files are already in place.

        Raises:
            ValueError: If the files are missing or corrupted, since downloading is unsupported.
        """
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        raise ValueError(
            "Downloading CelebA is not supported anymore as of 0.13, and the "
            "download parameter will be removed in 0.15. See "
            "https://github.com/pytorch/vision/issues/5705 for more details. "
            "Please download the files from "
            "https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract them "
            "in ``root/celeba``."
        )

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(image, target)`` for the sample at ``index``.

        The image is opened with ``PIL.Image.open`` and passed through ``transform`` if one is set.
        ``target`` has one entry per requested target type: a tuple if several were requested,
        the bare value if exactly one, and ``None`` if ``target_type`` is empty. A non-empty
        target is passed through ``target_transform`` if one is set.

        Raises:
            ValueError: If a requested target type is not recognized.
        """
        X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))

        target: Any = []
        for t in self.target_type:
            if t == "attr":
                target.append(self.attr[index, :])
            elif t == "identity":
                target.append(self.identity[index, 0])
            elif t == "bbox":
                target.append(self.bbox[index, :])
            elif t == "landmarks":
                target.append(self.landmarks_align[index, :])
            else:
                # TODO: refactor with utils.verify_str_arg
                raise ValueError(f'Target type "{t}" is not recognized.')

        if self.transform is not None:
            X = self.transform(X)

        if target:
            target = tuple(target) if len(target) > 1 else target[0]

            if self.target_transform is not None:
                target = self.target_transform(target)
        else:
            target = None

        return X, target

    def __len__(self) -> int:
        """Number of samples in the selected split."""
        return len(self.attr)

    def extra_repr(self) -> str:
        """Extra lines for the dataset's repr: the target types and the split."""
        lines = ["Target type: {target_type}", "Split: {split}"]
        return "\n".join(lines).format(**self.__dict__)