Unverified commit 44460c9c authored by Caroline Chen, committed by GitHub
Browse files

Remove pandas dependency for CelebA dataset (#3656)



* Remove pandas dependency for CelebA dataset

* address PR comments

* Apply suggestions from code review
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 1c1d5d87
...@@ -53,7 +53,6 @@ class LazyImporter: ...@@ -53,7 +53,6 @@ class LazyImporter:
MODULES = ( MODULES = (
"av", "av",
"lmdb", "lmdb",
"pandas",
"pycocotools", "pycocotools",
"requests", "requests",
"scipy.io", "scipy.io",
......
...@@ -616,7 +616,6 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase): ...@@ -616,7 +616,6 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
split=("train", "valid", "test", "all"), split=("train", "valid", "test", "all"),
target_type=("attr", "identity", "bbox", "landmarks", ["attr", "identity"]), target_type=("attr", "identity", "bbox", "landmarks", ["attr", "identity"]),
) )
REQUIRED_PACKAGES = ("pandas",)
_SPLIT_TO_IDX = dict(train=0, valid=1, test=2) _SPLIT_TO_IDX = dict(train=0, valid=1, test=2)
......
from collections import namedtuple
import csv
from functools import partial from functools import partial
import torch import torch
import os import os
...@@ -6,6 +8,8 @@ from typing import Any, Callable, List, Optional, Union, Tuple ...@@ -6,6 +8,8 @@ from typing import Any, Callable, List, Optional, Union, Tuple
from .vision import VisionDataset from .vision import VisionDataset
from .utils import download_file_from_google_drive, check_integrity, verify_str_arg from .utils import download_file_from_google_drive, check_integrity, verify_str_arg
CSV = namedtuple("CSV", ["header", "index", "data"])
class CelebA(VisionDataset): class CelebA(VisionDataset):
"""`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset. """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
...@@ -61,7 +65,6 @@ class CelebA(VisionDataset): ...@@ -61,7 +65,6 @@ class CelebA(VisionDataset):
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
) -> None: ) -> None:
import pandas
super(CelebA, self).__init__(root, transform=transform, super(CelebA, self).__init__(root, transform=transform,
target_transform=target_transform) target_transform=target_transform)
self.split = split self.split = split
...@@ -88,23 +91,42 @@ class CelebA(VisionDataset): ...@@ -88,23 +91,42 @@ class CelebA(VisionDataset):
} }
split_ = split_map[verify_str_arg(split.lower(), "split", split_ = split_map[verify_str_arg(split.lower(), "split",
("train", "valid", "test", "all"))] ("train", "valid", "test", "all"))]
splits = self._load_csv("list_eval_partition.txt")
identity = self._load_csv("identity_CelebA.txt")
bbox = self._load_csv("list_bbox_celeba.txt", header=1)
landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1)
attr = self._load_csv("list_attr_celeba.txt", header=1)
mask = slice(None) if split_ is None else (splits.data == split_).squeeze()
self.filename = splits.index
self.identity = identity.data[mask]
self.bbox = bbox.data[mask]
self.landmarks_align = landmarks_align.data[mask]
self.attr = attr.data[mask]
self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1}
self.attr_names = attr.header
def _load_csv(
    self,
    filename: str,
    header: Optional[int] = None,
) -> CSV:
    """Parse a whitespace-delimited CelebA annotation file.

    Args:
        filename: name of the annotation file, resolved relative to
            ``self.root``/``self.base_folder``.
        header: index of the row containing the column names. That row and
            every row before it are stripped from the data. ``None`` (default)
            means the file has no header row.

    Returns:
        CSV: named tuple ``(header, index, data)`` where ``header`` is the
        list of column names (empty when ``header`` is ``None``), ``index``
        holds the first field of each remaining row (the image file names),
        and ``data`` is an integer tensor built from the remaining fields.
    """
    fn = partial(os.path.join, self.root, self.base_folder)
    # skipinitialspace collapses the runs of padding spaces the CelebA
    # annotation files use to align their columns.
    with open(fn(filename)) as csv_file:
        data = list(csv.reader(csv_file, delimiter=' ', skipinitialspace=True))

    if header is not None:
        headers = data[header]
        data = data[header + 1:]
    else:
        headers = []

    indices = [row[0] for row in data]
    data = [row[1:] for row in data]
    # All annotation values (split ids, identities, bboxes, landmarks,
    # +/-1 attributes) are integers, so a single int tensor suffices.
    data_int = [list(map(int, i)) for i in data]

    return CSV(headers, indices, torch.tensor(data_int))
self.attr = torch.as_tensor(attr[mask].values)
self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1}
self.attr_names = list(attr.columns)
def _check_integrity(self) -> bool: def _check_integrity(self) -> bool:
for (_, md5, filename) in self.file_list: for (_, md5, filename) in self.file_list:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment