Unverified Commit d2486f6c authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Put back CelebA download (#6147)

* Revert "Indicate Celeba download parameter is deprecated and will be removed (#6059)"

This reverts commit 49496c4f.

* Revert "Deactivate CelebA download (#6052)"

This reverts commit ac56f52e.
parent efc67ea9
import csv import csv
import os import os
import warnings
from collections import namedtuple from collections import namedtuple
from typing import Any, Callable, List, Optional, Union, Tuple from typing import Any, Callable, List, Optional, Union, Tuple
import PIL import PIL
import torch import torch
from .utils import check_integrity, verify_str_arg from .utils import download_file_from_google_drive, check_integrity, verify_str_arg, extract_archive
from .vision import VisionDataset from .vision import VisionDataset
CSV = namedtuple("CSV", ["header", "index", "data"]) CSV = namedtuple("CSV", ["header", "index", "data"])
...@@ -36,17 +35,9 @@ class CelebA(VisionDataset): ...@@ -36,17 +35,9 @@ class CelebA(VisionDataset):
and returns a transformed version. E.g, ``transforms.PILToTensor`` and returns a transformed version. E.g, ``transforms.PILToTensor``
target_transform (callable, optional): A function/transform that takes in the target_transform (callable, optional): A function/transform that takes in the
target and transforms it. target and transforms it.
download (bool, optional): Deprecated. download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
.. warning:: downloaded again.
Downloading CelebA is not supported anymore as of 0.13 and this
parameter will be removed in 0.15. See
`this issue <https://github.com/pytorch/vision/issues/5705>`__
for more details.
Please download the files from
https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract
them in ``root/celeba``.
""" """
base_folder = "celeba" base_folder = "celeba"
...@@ -73,7 +64,7 @@ class CelebA(VisionDataset): ...@@ -73,7 +64,7 @@ class CelebA(VisionDataset):
target_type: Union[List[str], str] = "attr", target_type: Union[List[str], str] = "attr",
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = None, download: bool = False,
) -> None: ) -> None:
super().__init__(root, transform=transform, target_transform=target_transform) super().__init__(root, transform=transform, target_transform=target_transform)
self.split = split self.split = split
...@@ -85,15 +76,6 @@ class CelebA(VisionDataset): ...@@ -85,15 +76,6 @@ class CelebA(VisionDataset):
if not self.target_type and self.target_transform is not None: if not self.target_type and self.target_transform is not None:
raise RuntimeError("target_transform is specified but target_type is empty") raise RuntimeError("target_transform is specified but target_type is empty")
if download is not None:
warnings.warn(
"Downloading CelebA is not supported anymore as of 0.13, and the "
"download parameter will be removed in 0.15. See "
"https://github.com/pytorch/vision/issues/5705 for more details. "
"Please download the files from "
"https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract them "
"in ``root/celeba``."
)
if download: if download:
self.download() self.download()
...@@ -164,14 +146,10 @@ class CelebA(VisionDataset): ...@@ -164,14 +146,10 @@ class CelebA(VisionDataset):
print("Files already downloaded and verified") print("Files already downloaded and verified")
return return
raise ValueError( for (file_id, md5, filename) in self.file_list:
"Downloading CelebA is not supported anymore as of 0.13, and the " download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
"download parameter will be removed in 0.15. See "
"https://github.com/pytorch/vision/issues/5705 for more details. " extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"))
"Please download the files from "
"https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract them "
"in ``root/celeba``."
)
def __getitem__(self, index: int) -> Tuple[Any, Any]: def __getitem__(self, index: int) -> Tuple[Any, Any]:
X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])) X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
......
...@@ -11,7 +11,7 @@ from torchdata.datapipes.iter import ( ...@@ -11,7 +11,7 @@ from torchdata.datapipes.iter import (
) )
from torchvision.prototype.datasets.utils import ( from torchvision.prototype.datasets.utils import (
Dataset, Dataset,
ManualDownloadResource, GDriveResource,
OnlineResource, OnlineResource,
) )
from torchvision.prototype.datasets.utils._internal import ( from torchvision.prototype.datasets.utils._internal import (
...@@ -85,34 +85,33 @@ class CelebA(Dataset): ...@@ -85,34 +85,33 @@ class CelebA(Dataset):
super().__init__(root, skip_integrity_check=skip_integrity_check) super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]: def _resources(self) -> List[OnlineResource]:
instructions = "Please download the file from https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html." splits = GDriveResource(
splits = ManualDownloadResource( "0B7EVK8r0v71pY0NSMzRuSXJEVkk",
instructions=instructions,
sha256="fc955bcb3ef8fbdf7d5640d9a8693a8431b5f2ee291a5c1449a1549e7e073fe7", sha256="fc955bcb3ef8fbdf7d5640d9a8693a8431b5f2ee291a5c1449a1549e7e073fe7",
file_name="list_eval_partition.txt", file_name="list_eval_partition.txt",
) )
images = ManualDownloadResource( images = GDriveResource(
instructions=instructions, "0B7EVK8r0v71pZjFTYXZWM3FlRnM",
sha256="46fb89443c578308acf364d7d379fe1b9efb793042c0af734b6112e4fd3a8c74", sha256="46fb89443c578308acf364d7d379fe1b9efb793042c0af734b6112e4fd3a8c74",
file_name="img_align_celeba.zip", file_name="img_align_celeba.zip",
) )
identities = ManualDownloadResource( identities = GDriveResource(
instructions=instructions, "1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS",
sha256="c6143857c3e2630ac2da9f782e9c1232e5e59be993a9d44e8a7916c78a6158c0", sha256="c6143857c3e2630ac2da9f782e9c1232e5e59be993a9d44e8a7916c78a6158c0",
file_name="identity_CelebA.txt", file_name="identity_CelebA.txt",
) )
attributes = ManualDownloadResource( attributes = GDriveResource(
instructions=instructions, "0B7EVK8r0v71pblRyaVFSWGxPY0U",
sha256="f0e5da289d5ccf75ffe8811132694922b60f2af59256ed362afa03fefba324d0", sha256="f0e5da289d5ccf75ffe8811132694922b60f2af59256ed362afa03fefba324d0",
file_name="list_attr_celeba.txt", file_name="list_attr_celeba.txt",
) )
bounding_boxes = ManualDownloadResource( bounding_boxes = GDriveResource(
instructions=instructions, "0B7EVK8r0v71pbThiMVRxWXZ4dU0",
sha256="7487a82e57c4bb956c5445ae2df4a91ffa717e903c5fa22874ede0820c8ec41b", sha256="7487a82e57c4bb956c5445ae2df4a91ffa717e903c5fa22874ede0820c8ec41b",
file_name="list_bbox_celeba.txt", file_name="list_bbox_celeba.txt",
) )
landmarks = ManualDownloadResource( landmarks = GDriveResource(
instructions=instructions, "0B7EVK8r0v71pd0FJY3Blby1HUTQ",
sha256="6c02a87569907f6db2ba99019085697596730e8129f67a3d61659f198c48d43b", sha256="6c02a87569907f6db2ba99019085697596730e8129f67a3d61659f198c48d43b",
file_name="list_landmarks_align_celeba.txt", file_name="list_landmarks_align_celeba.txt",
) )
......
...@@ -216,9 +216,9 @@ class ManualDownloadResource(OnlineResource): ...@@ -216,9 +216,9 @@ class ManualDownloadResource(OnlineResource):
def _download(self, root: pathlib.Path) -> NoReturn: def _download(self, root: pathlib.Path) -> NoReturn:
raise RuntimeError( raise RuntimeError(
f"The file {self.file_name} was not found, and cannot be downloaded automatically.\n\n" f"The file {self.file_name} cannot be downloaded automatically. "
f"{self.instructions.strip()}\n\n" f"Please follow the instructions below and place it in {root}\n\n"
f"Once it is downloaded, please place the file in {root}." f"{self.instructions}"
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment