"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "658e24e86c4c52ee14244ab7a7113f5bf353186e"
Unverified Commit 40333c5a authored by Philip Meier, committed by GitHub

celeba (#2522)

parent 31245cb8
@@ -2,6 +2,7 @@ from functools import partial
 import torch
 import os
 import PIL
+from typing import Any, Callable, List, Optional, Union, Tuple
 from .vision import VisionDataset
 from .utils import download_file_from_google_drive, check_integrity, verify_str_arg
@@ -48,8 +49,15 @@ class CelebA(VisionDataset):
         ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
     ]
 
-    def __init__(self, root, split="train", target_type="attr", transform=None,
-                 target_transform=None, download=False):
+    def __init__(
+            self,
+            root: str,
+            split: str = "train",
+            target_type: Union[List[str], str] = "attr",
+            transform: Optional[Callable] = None,
+            target_transform: Optional[Callable] = None,
+            download: bool = False,
+    ) -> None:
         import pandas
         super(CelebA, self).__init__(root, transform=transform,
                                      target_transform=target_transform)
@@ -75,8 +83,8 @@ class CelebA(VisionDataset):
             "test": 2,
             "all": None,
         }
-        split = split_map[verify_str_arg(split.lower(), "split",
-                                         ("train", "valid", "test", "all"))]
+        split_ = split_map[verify_str_arg(split.lower(), "split",
+                                          ("train", "valid", "test", "all"))]
 
         fn = partial(os.path.join, self.root, self.base_folder)
         splits = pandas.read_csv(fn("list_eval_partition.txt"), delim_whitespace=True, header=None, index_col=0)
@@ -85,7 +93,7 @@ class CelebA(VisionDataset):
         landmarks_align = pandas.read_csv(fn("list_landmarks_align_celeba.txt"), delim_whitespace=True, header=1)
         attr = pandas.read_csv(fn("list_attr_celeba.txt"), delim_whitespace=True, header=1)
 
-        mask = slice(None) if split is None else (splits[1] == split)
+        mask = slice(None) if split_ is None else (splits[1] == split_)
 
         self.filename = splits[mask].index.values
         self.identity = torch.as_tensor(identity[mask].values)
@@ -95,7 +103,7 @@ class CelebA(VisionDataset):
         self.attr = (self.attr + 1) // 2  # map from {-1, 1} to {0, 1}
         self.attr_names = list(attr.columns)
 
-    def _check_integrity(self):
+    def _check_integrity(self) -> bool:
         for (_, md5, filename) in self.file_list:
             fpath = os.path.join(self.root, self.base_folder, filename)
             _, ext = os.path.splitext(filename)
@@ -107,7 +115,7 @@ class CelebA(VisionDataset):
         # Should check a hash of the images
         return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))
 
-    def download(self):
+    def download(self) -> None:
         import zipfile
 
         if self._check_integrity():
@@ -120,10 +128,10 @@ class CelebA(VisionDataset):
         with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
             f.extractall(os.path.join(self.root, self.base_folder))
 
-    def __getitem__(self, index):
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
         X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
 
-        target = []
+        target: Any = []
         for t in self.target_type:
             if t == "attr":
                 target.append(self.attr[index, :])
@@ -150,9 +158,9 @@ class CelebA(VisionDataset):
         return X, target
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.attr)
 
-    def extra_repr(self):
+    def extra_repr(self) -> str:
         lines = ["Target type: {target_type}", "Split: {split}"]
         return '\n'.join(lines).format(**self.__dict__)
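
For orientation only (not part of the commit): the annotations do not change how the dataset is called. A minimal usage sketch, assuming torchvision is installed, that "./data" is a placeholder root directory, and that the Google Drive files in CelebA.file_list are still reachable (otherwise place the extracted CelebA files under ./data/celeba and pass download=False):

from torchvision import datasets, transforms

# Sketch under the assumptions above; "./data" is a hypothetical root.
celeba = datasets.CelebA(
    root="./data",
    split="valid",                    # checked against ("train", "valid", "test", "all")
    target_type="attr",               # Union[List[str], str], default "attr"
    transform=transforms.ToTensor(),
    download=True,
)

image, attr = celeba[0]               # __getitem__ -> Tuple[Any, Any]
print(len(celeba), tuple(image.shape), tuple(attr.shape))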