Unverified Commit 05dcf50a authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

use helper function to extract archive in CelebA (#4557)


Co-authored-by: default avatarPrabhat Roy <prabhatroy@fb.com>
parent 32df801a
...@@ -6,7 +6,7 @@ from typing import Any, Callable, List, Optional, Union, Tuple ...@@ -6,7 +6,7 @@ from typing import Any, Callable, List, Optional, Union, Tuple
import PIL import PIL
import torch import torch
from .utils import download_file_from_google_drive, check_integrity, verify_str_arg from .utils import download_file_from_google_drive, check_integrity, verify_str_arg, extract_archive
from .vision import VisionDataset from .vision import VisionDataset
CSV = namedtuple("CSV", ["header", "index", "data"]) CSV = namedtuple("CSV", ["header", "index", "data"])
...@@ -142,8 +142,6 @@ class CelebA(VisionDataset): ...@@ -142,8 +142,6 @@ class CelebA(VisionDataset):
return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba")) return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))
def download(self) -> None: def download(self) -> None:
import zipfile
if self._check_integrity(): if self._check_integrity():
print("Files already downloaded and verified") print("Files already downloaded and verified")
return return
...@@ -151,8 +149,7 @@ class CelebA(VisionDataset): ...@@ -151,8 +149,7 @@ class CelebA(VisionDataset):
for (file_id, md5, filename) in self.file_list: for (file_id, md5, filename) in self.file_list:
download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5) download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f: extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"))
f.extractall(os.path.join(self.root, self.base_folder))
def __getitem__(self, index: int) -> Tuple[Any, Any]: def __getitem__(self, index: int) -> Tuple[Any, Any]:
X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])) X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment