Unverified Commit 40333c5a authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

celba (#2522)

parent 31245cb8
......@@ -2,6 +2,7 @@ from functools import partial
import torch
import os
import PIL
from typing import Any, Callable, List, Optional, Union, Tuple
from .vision import VisionDataset
from .utils import download_file_from_google_drive, check_integrity, verify_str_arg
......@@ -48,8 +49,15 @@ class CelebA(VisionDataset):
("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
]
def __init__(self, root, split="train", target_type="attr", transform=None,
target_transform=None, download=False):
def __init__(
self,
root: str,
split: str = "train",
target_type: Union[List[str], str] = "attr",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
import pandas
super(CelebA, self).__init__(root, transform=transform,
target_transform=target_transform)
......@@ -75,8 +83,8 @@ class CelebA(VisionDataset):
"test": 2,
"all": None,
}
split = split_map[verify_str_arg(split.lower(), "split",
("train", "valid", "test", "all"))]
split_ = split_map[verify_str_arg(split.lower(), "split",
("train", "valid", "test", "all"))]
fn = partial(os.path.join, self.root, self.base_folder)
splits = pandas.read_csv(fn("list_eval_partition.txt"), delim_whitespace=True, header=None, index_col=0)
......@@ -85,7 +93,7 @@ class CelebA(VisionDataset):
landmarks_align = pandas.read_csv(fn("list_landmarks_align_celeba.txt"), delim_whitespace=True, header=1)
attr = pandas.read_csv(fn("list_attr_celeba.txt"), delim_whitespace=True, header=1)
mask = slice(None) if split is None else (splits[1] == split)
mask = slice(None) if split_ is None else (splits[1] == split_)
self.filename = splits[mask].index.values
self.identity = torch.as_tensor(identity[mask].values)
......@@ -95,7 +103,7 @@ class CelebA(VisionDataset):
self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1}
self.attr_names = list(attr.columns)
def _check_integrity(self):
def _check_integrity(self) -> bool:
for (_, md5, filename) in self.file_list:
fpath = os.path.join(self.root, self.base_folder, filename)
_, ext = os.path.splitext(filename)
......@@ -107,7 +115,7 @@ class CelebA(VisionDataset):
# Should check a hash of the images
return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))
def download(self):
def download(self) -> None:
import zipfile
if self._check_integrity():
......@@ -120,10 +128,10 @@ class CelebA(VisionDataset):
with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
f.extractall(os.path.join(self.root, self.base_folder))
def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
target = []
target: Any = []
for t in self.target_type:
if t == "attr":
target.append(self.attr[index, :])
......@@ -150,9 +158,9 @@ class CelebA(VisionDataset):
return X, target
def __len__(self):
def __len__(self) -> int:
return len(self.attr)
def extra_repr(self):
def extra_repr(self) -> str:
lines = ["Target type: {target_type}", "Split: {split}"]
return '\n'.join(lines).format(**self.__dict__)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment