Unverified commit ac56f52e, authored by Nicolas Hug, committed by GitHub
Browse files

Deactivate CelebA download (#6052)



* Deactivate CelebA download

* flake8

* Do proto version

* Update torchvision/prototype/datasets/utils/_resource.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* address review
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent 37665a0b
...@@ -6,7 +6,7 @@ from typing import Any, Callable, List, Optional, Union, Tuple ...@@ -6,7 +6,7 @@ from typing import Any, Callable, List, Optional, Union, Tuple
import PIL import PIL
import torch import torch
from .utils import download_file_from_google_drive, check_integrity, verify_str_arg, extract_archive from .utils import check_integrity, verify_str_arg
from .vision import VisionDataset from .vision import VisionDataset
CSV = namedtuple("CSV", ["header", "index", "data"]) CSV = namedtuple("CSV", ["header", "index", "data"])
...@@ -35,9 +35,16 @@ class CelebA(VisionDataset): ...@@ -35,9 +35,16 @@ class CelebA(VisionDataset):
and returns a transformed version. E.g, ``transforms.PILToTensor`` and returns a transformed version. E.g, ``transforms.PILToTensor``
target_transform (callable, optional): A function/transform that takes in the target_transform (callable, optional): A function/transform that takes in the
target and transforms it. target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and download (bool, optional): Unsupported.
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. .. warning::
Downloading CelebA is not supported anymore as of 0.13. See
`this issue <https://github.com/pytorch/vision/issues/5705>`__
for more details.
Please download the files from
https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract
them in ``root/celeba``.
""" """
base_folder = "celeba" base_folder = "celeba"
...@@ -146,10 +153,13 @@ class CelebA(VisionDataset): ...@@ -146,10 +153,13 @@ class CelebA(VisionDataset):
print("Files already downloaded and verified") print("Files already downloaded and verified")
return return
for (file_id, md5, filename) in self.file_list: raise ValueError(
download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5) "Downloading CelebA is not supported anymore as of 0.13. See "
"https://github.com/pytorch/vision/issues/5705 for more details. "
extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip")) "Please download the files from "
"https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract them "
"in ``root/celeba``."
)
def __getitem__(self, index: int) -> Tuple[Any, Any]: def __getitem__(self, index: int) -> Tuple[Any, Any]:
X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])) X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
......
...@@ -11,7 +11,7 @@ from torchdata.datapipes.iter import ( ...@@ -11,7 +11,7 @@ from torchdata.datapipes.iter import (
) )
from torchvision.prototype.datasets.utils import ( from torchvision.prototype.datasets.utils import (
Dataset, Dataset,
GDriveResource, ManualDownloadResource,
OnlineResource, OnlineResource,
) )
from torchvision.prototype.datasets.utils._internal import ( from torchvision.prototype.datasets.utils._internal import (
...@@ -85,33 +85,34 @@ class CelebA(Dataset): ...@@ -85,33 +85,34 @@ class CelebA(Dataset):
super().__init__(root, skip_integrity_check=skip_integrity_check) super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]: def _resources(self) -> List[OnlineResource]:
splits = GDriveResource( instructions = "Please download the file from https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html."
"0B7EVK8r0v71pY0NSMzRuSXJEVkk", splits = ManualDownloadResource(
instructions=instructions,
sha256="fc955bcb3ef8fbdf7d5640d9a8693a8431b5f2ee291a5c1449a1549e7e073fe7", sha256="fc955bcb3ef8fbdf7d5640d9a8693a8431b5f2ee291a5c1449a1549e7e073fe7",
file_name="list_eval_partition.txt", file_name="list_eval_partition.txt",
) )
images = GDriveResource( images = ManualDownloadResource(
"0B7EVK8r0v71pZjFTYXZWM3FlRnM", instructions=instructions,
sha256="46fb89443c578308acf364d7d379fe1b9efb793042c0af734b6112e4fd3a8c74", sha256="46fb89443c578308acf364d7d379fe1b9efb793042c0af734b6112e4fd3a8c74",
file_name="img_align_celeba.zip", file_name="img_align_celeba.zip",
) )
identities = GDriveResource( identities = ManualDownloadResource(
"1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", instructions=instructions,
sha256="c6143857c3e2630ac2da9f782e9c1232e5e59be993a9d44e8a7916c78a6158c0", sha256="c6143857c3e2630ac2da9f782e9c1232e5e59be993a9d44e8a7916c78a6158c0",
file_name="identity_CelebA.txt", file_name="identity_CelebA.txt",
) )
attributes = GDriveResource( attributes = ManualDownloadResource(
"0B7EVK8r0v71pblRyaVFSWGxPY0U", instructions=instructions,
sha256="f0e5da289d5ccf75ffe8811132694922b60f2af59256ed362afa03fefba324d0", sha256="f0e5da289d5ccf75ffe8811132694922b60f2af59256ed362afa03fefba324d0",
file_name="list_attr_celeba.txt", file_name="list_attr_celeba.txt",
) )
bounding_boxes = GDriveResource( bounding_boxes = ManualDownloadResource(
"0B7EVK8r0v71pbThiMVRxWXZ4dU0", instructions=instructions,
sha256="7487a82e57c4bb956c5445ae2df4a91ffa717e903c5fa22874ede0820c8ec41b", sha256="7487a82e57c4bb956c5445ae2df4a91ffa717e903c5fa22874ede0820c8ec41b",
file_name="list_bbox_celeba.txt", file_name="list_bbox_celeba.txt",
) )
landmarks = GDriveResource( landmarks = ManualDownloadResource(
"0B7EVK8r0v71pd0FJY3Blby1HUTQ", instructions=instructions,
sha256="6c02a87569907f6db2ba99019085697596730e8129f67a3d61659f198c48d43b", sha256="6c02a87569907f6db2ba99019085697596730e8129f67a3d61659f198c48d43b",
file_name="list_landmarks_align_celeba.txt", file_name="list_landmarks_align_celeba.txt",
) )
......
...@@ -216,9 +216,9 @@ class ManualDownloadResource(OnlineResource): ...@@ -216,9 +216,9 @@ class ManualDownloadResource(OnlineResource):
def _download(self, root: pathlib.Path) -> NoReturn: def _download(self, root: pathlib.Path) -> NoReturn:
raise RuntimeError( raise RuntimeError(
f"The file {self.file_name} cannot be downloaded automatically. " f"The file {self.file_name} was not found, and cannot be downloaded automatically.\n\n"
f"Please follow the instructions below and place it in {root}\n\n" f"{self.instructions.strip()}\n\n"
f"{self.instructions}" f"Once it is downloaded, please place the file in {root}."
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment