Unverified commit 6a936e48, authored by Philip Meier, committed by GitHub

add gdown as optional requirement for dataset GDrive download (#8237)

parent 4c0f4414
@@ -36,7 +36,7 @@ jobs:
         run: pip install --no-build-isolation --editable .
       - name: Install all optional dataset requirements
-        run: pip install scipy pycocotools lmdb requests
+        run: pip install scipy pycocotools lmdb gdown
       - name: Install tests requirements
         run: pip install pytest
...
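The CI step above doubles as the recipe for running the dataset tests locally; the equivalent install, with gdown taking over the role requests previously played for GDrive downloads, is:

    pip install scipy pycocotools lmdb gdown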
@@ -142,3 +142,7 @@ ignore_missing_imports = True
 [mypy-h5py.*]
 
 ignore_missing_imports = True
+
+[mypy-gdown.*]
+
+ignore_missing_imports = True
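gdown ships without type stubs, so mypy cannot resolve the new lazy import on its own; the section above suppresses that check for the gdown namespace only. The kind of error it silences looks roughly like this (illustrative, not captured output; exact wording depends on the mypy version and file path):

    error: Cannot find implementation or library stub for module named "gdown"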
@@ -59,7 +59,6 @@ if os.getenv("PYTORCH_VERSION"):
 requirements = [
     "numpy",
-    "requests",
     pytorch_dep,
 ]
...
@@ -30,6 +30,10 @@ class Caltech101(VisionDataset):
         download (bool, optional): If true, downloads the dataset from the internet and
             puts it in root directory. If dataset is already downloaded, it is not
             downloaded again.
+
+            .. warning::
+
+                To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
     """
 
     def __init__(
...
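For context on the warning added above: the gdown requirement only bites when download=True routes through Google Drive. A minimal usage sketch (the root path is illustrative):

    from torchvision.datasets import Caltech101

    # download=True fetches the archives from Google Drive and therefore
    # needs gdown; with download=False no new dependency is involved.
    dataset = Caltech101(root="data", download=True)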
@@ -38,6 +38,10 @@ class CelebA(VisionDataset):
         download (bool, optional): If true, downloads the dataset from the internet and
             puts it in root directory. If dataset is already downloaded, it is not
             downloaded again.
+
+            .. warning::
+
+                To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
     """
 
     base_folder = "celeba"
...
@@ -25,6 +25,10 @@ class PCAM(VisionDataset):
         target_transform (callable, optional): A function/transform that takes in the target and transforms it.
         download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/pcam``. If
             dataset is already downloaded, it is not downloaded again.
+
+            .. warning::
+
+                To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
     """
 
     _FILES = {
...
 import bz2
-import contextlib
 import gzip
 import hashlib
-import itertools
 import lzma
 import os
 import os.path
@@ -13,13 +11,11 @@ import tarfile
 import urllib
 import urllib.error
 import urllib.request
-import warnings
 import zipfile
 from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union
 from urllib.parse import urlparse
 
 import numpy as np
-import requests
 import torch
 from torch.utils.model_zoo import tqdm
@@ -191,22 +187,6 @@ def list_files(root: Union[str, pathlib.Path], suffix: str, prefix: bool = False
     return files
 
 
-def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[bytes, Iterator[bytes]]:
-    content = response.iter_content(chunk_size)
-    first_chunk = None
-    # filter out keep-alive new chunks
-    while not first_chunk:
-        first_chunk = next(content)
-    content = itertools.chain([first_chunk], content)
-
-    try:
-        match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
-        api_response = match["api_response"] if match is not None else None
-    except UnicodeDecodeError:
-        api_response = None
-
-    return api_response, content
-
-
 def download_file_from_google_drive(
     file_id: str,
     root: Union[str, pathlib.Path],
@@ -221,7 +201,12 @@ def download_file_from_google_drive(
         filename (str, optional): Name to save the file under. If None, use the id of the file.
         md5 (str, optional): MD5 checksum of the download. If None, do not check
     """
-    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
+    try:
+        import gdown
+    except ModuleNotFoundError:
+        raise RuntimeError(
+            "To download files from GDrive, 'gdown' is required. You can install it with 'pip install gdown'."
+        )
 
     root = os.path.expanduser(root)
     if not filename:
@@ -234,51 +219,10 @@ def download_file_from_google_drive(
         print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
         return
 
-    url = "https://drive.google.com/uc"
-    params = dict(id=file_id, export="download")
-    with requests.Session() as session:
-        response = session.get(url, params=params, stream=True)
-
-        for key, value in response.cookies.items():
-            if key.startswith("download_warning"):
-                token = value
-                break
-        else:
-            api_response, content = _extract_gdrive_api_response(response)
-            token = "t" if api_response == "Virus scan warning" else None
-
-        if token is not None:
-            response = session.get(url, params=dict(params, confirm=token), stream=True)
-            api_response, content = _extract_gdrive_api_response(response)
-
-        if api_response == "Quota exceeded":
-            raise RuntimeError(
-                f"The daily quota of the file {filename} is exceeded and it "
-                f"can't be downloaded. This is a limitation of Google Drive "
-                f"and can only be overcome by trying again later."
-            )
-
-        _save_response_content(content, fpath)
-
-    # In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text
-    if os.stat(fpath).st_size < 10 * 1024:
-        with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh:
-            text = fh.read()
-            # Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604
-            if re.search(r"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text):
-                warnings.warn(
-                    f"We detected some HTML elements in the downloaded file. "
-                    f"This most likely means that the download triggered an unhandled API response by GDrive. "
-                    f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
-                    f"the response:\n\n{text}"
-                )
-
-    if md5 and not check_md5(fpath, md5):
-        raise RuntimeError(
-            f"The MD5 checksum of the download file {fpath} does not match the one on record."
-            f"Please delete the file and try again. "
-            f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues."
-        )
+    gdown.download(id=file_id, output=fpath, quiet=False, user_agent=USER_AGENT)
+
+    if not check_integrity(fpath, md5):
+        raise RuntimeError("File not found or corrupted.")
 
 
 def _extract_tar(
...
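The hand-rolled requests session, confirm-token dance, and HTML sniffing removed above collapse into a single gdown.download call, with verification delegated to check_integrity (which also covers the md5=None case by only checking that the file exists). A minimal standalone sketch of the same call; the file id and output name are hypothetical, and the id= keyword assumes a reasonably recent gdown release, as the diff itself does:

    import gdown

    # gdown internally handles the Google Drive confirmation token and
    # virus-scan interstitial that the removed code implemented by hand.
    gdown.download(id="FILE_ID", output="archive.zip", quiet=False)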
@@ -34,6 +34,10 @@ class WIDERFace(VisionDataset):
             puts it in root directory. If dataset is already downloaded, it is not
             downloaded again.
+
+            .. warning::
+
+                To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
     """
 
     BASE_FOLDER = "widerface"
...
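End to end, a user who triggers a GDrive-backed download without gdown installed now gets an actionable RuntimeError at call time rather than an import failure at module load, since requests is no longer imported at the top of the module. A minimal sketch (dataset choice and path are illustrative):

    from torchvision.datasets import WIDERFace

    try:
        WIDERFace(root="data", split="train", download=True)
    except RuntimeError as err:
        # e.g. "To download files from GDrive, 'gdown' is required. ..."
        print(err)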