Unverified Commit 96f2c0d4 authored by Philip Meier, committed by GitHub

support confirming no virus scan on GDrive download (#5645)



* support confirming no virus scan on GDrive download

* put gen_bar_updater back

* Update torchvision/datasets/utils.py
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent b7c59a08
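
For context, the "virus scan warning" interstitial this commit handles can be reproduced outside torchvision. Below is a minimal sketch of the confirmation round-trip, assuming a placeholder file id; it uses the same third-party `requests` library as the code in this diff, but the Content-Type check is a simplification of the actual change, which parses the HTML `<title>` instead:

```python
import requests

FILE_ID = "FILE_ID"  # placeholder: any Drive file too large for the virus scan

with requests.Session() as session:
    params = {"id": FILE_ID, "export": "download"}
    response = session.get("https://drive.google.com/uc", params=params, stream=True)

    # Small files stream immediately; large files instead come back as an HTML
    # page titled "Google Drive - Virus scan warning". Retrying with confirm=t
    # acknowledges the warning and starts the actual download.
    if "text/html" in response.headers.get("Content-Type", ""):
        response = session.get(
            "https://drive.google.com/uc", params=dict(params, confirm="t"), stream=True
        )

    with open("download.bin", "wb") as fh:
        for chunk in response.iter_content(32 * 1024):
            if chunk:  # skip keep-alive chunks
                fh.write(chunk)
```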
@@ -11,6 +11,7 @@ import tarfile
 import urllib
 import urllib.error
 import urllib.request
+import warnings
 import zipfile
 from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
 from urllib.parse import urlparse
@@ -24,22 +25,31 @@ from .._internally_replaced_utils import (
     _is_remote_location_available,
 )


 USER_AGENT = "pytorch/vision"


-def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
-    with open(filename, "wb") as fh:
-        with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
-            with tqdm(total=response.length) as pbar:
-                for chunk in iter(lambda: response.read(chunk_size), ""):
-                    if not chunk:
-                        break
-                    pbar.update(chunk_size)
-                    fh.write(chunk)
+def _save_response_content(
+    content: Iterator[bytes],
+    destination: str,
+    length: Optional[int] = None,
+) -> None:
+    with open(destination, "wb") as fh, tqdm(total=length) as pbar:
+        for chunk in content:
+            # filter out keep-alive new chunks
+            if not chunk:
+                continue
+            fh.write(chunk)
+            pbar.update(len(chunk))
+
+
+def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
+    with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
+        _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)


 def gen_bar_updater() -> Callable[[int, int, int], None]:
+    warnings.warn("The function `gen_bar_updater` is deprecated since 0.13 and will be removed in 0.15.")
     pbar = tqdm(total=None)

     def bar_update(count, block_size, total_size):
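
The rewritten `_urlretrieve` leans on the two-argument form of `iter`, which keeps calling `response.read(chunk_size)` until it returns the sentinel `b""` that signals EOF. Note the old code compared against `""`, a `str` sentinel a bytes stream can never produce, so it only terminated via the `if not chunk: break` guard. A standalone illustration of the pattern:

```python
import io

buffer = io.BytesIO(b"0123456789" * 10)

# iter(callable, sentinel) calls the callable until it returns the sentinel.
# read() on a binary stream returns b"" at EOF, so this yields fixed-size
# chunks and stops cleanly without an explicit break.
chunks = list(iter(lambda: buffer.read(32), b""))
assert b"".join(chunks) == b"0123456789" * 10
assert [len(c) for c in chunks] == [32, 32, 32, 4]
```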
@@ -184,11 +194,20 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
     return files


-def _quota_exceeded(first_chunk: bytes) -> bool:
+def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[bytes, Iterator[bytes]]:
+    content = response.iter_content(chunk_size)
+    first_chunk = None
+    # filter out keep-alive new chunks
+    while not first_chunk:
+        first_chunk = next(content)
+    content = itertools.chain([first_chunk], content)
+
     try:
-        return "Google Drive - Quota exceeded" in first_chunk.decode()
+        match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
+        api_response = match["api_response"] if match is not None else None
     except UnicodeDecodeError:
-        return False
+        api_response = None
+    return api_response, content


 def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
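
`_extract_gdrive_api_response` peeks at the first non-empty chunk to classify the response, then stitches it back with `itertools.chain` so the caller still sees the complete payload. The same pattern in isolation, with a hand-rolled generator standing in for `response.iter_content`:

```python
import itertools

def fake_iter_content():
    yield b""  # keep-alive style empty chunk, skipped by the loop below
    yield b"<title>Google Drive - Virus scan warning</title>"
    yield b"...remainder of the page..."

content = fake_iter_content()
first_chunk = None
while not first_chunk:  # filter out keep-alive chunks
    first_chunk = next(content)

# Re-attach the inspected chunk so downstream consumers get the full stream.
content = itertools.chain([first_chunk], content)
assert b"".join(content).startswith(b"<title>Google Drive")
```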
@@ -202,8 +221,6 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
     """
     # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url

-    url = "https://docs.google.com/uc?export=download"
-
     root = os.path.expanduser(root)
     if not filename:
         filename = file_id
@@ -211,61 +228,34 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
     os.makedirs(root, exist_ok=True)

-    if os.path.isfile(fpath) and check_integrity(fpath, md5):
-        print("Using downloaded and verified file: " + fpath)
-    else:
-        session = requests.Session()
-
-        response = session.get(url, params={"id": file_id}, stream=True)
-        token = _get_confirm_token(response)
-
-        if token:
-            params = {"id": file_id, "confirm": token}
-            response = session.get(url, params=params, stream=True)
-
-        # Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent
-        # with their own API, refer https://github.com/pytorch/vision/issues/2992#issuecomment-730614517.
-        # Should this be fixed at some place in future, one could refactor the following to no longer rely on decoding
-        # the first_chunk of the payload
-        response_content_generator = response.iter_content(32768)
-        first_chunk = None
-        while not first_chunk:  # filter out keep-alive new chunks
-            first_chunk = next(response_content_generator)
-
-        if _quota_exceeded(first_chunk):
-            msg = (
-                f"The daily quota of the file {filename} is exceeded and it "
-                f"can't be downloaded. This is a limitation of Google Drive "
-                f"and can only be overcome by trying again later."
-            )
-            raise RuntimeError(msg)
-
-        _save_response_content(itertools.chain((first_chunk,), response_content_generator), fpath)
-        response.close()
-
-
-def _get_confirm_token(response: requests.models.Response) -> Optional[str]:
-    for key, value in response.cookies.items():
-        if key.startswith("download_warning"):
-            return value
-
-    return None
-
-
-def _save_response_content(
-    response_gen: Iterator[bytes],
-    destination: str,
-) -> None:
-    with open(destination, "wb") as f:
-        pbar = tqdm(total=None)
-        progress = 0
-        for chunk in response_gen:
-            if chunk:  # filter out keep-alive new chunks
-                f.write(chunk)
-                progress += len(chunk)
-                pbar.update(progress - pbar.n)
-        pbar.close()
+    if check_integrity(fpath, md5):
+        print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
+        return
+
+    url = "https://drive.google.com/uc"
+    params = dict(id=file_id, export="download")
+    with requests.Session() as session:
+        response = session.get(url, params=params, stream=True)
+
+        for key, value in response.cookies.items():
+            if key.startswith("download_warning"):
+                token = value
+                break
+        else:
+            api_response, content = _extract_gdrive_api_response(response)
+            token = "t" if api_response == "Virus scan warning" else None
+
+        if token is not None:
+            response = session.get(url, params=dict(params, confirm=token), stream=True)
+            api_response, content = _extract_gdrive_api_response(response)
+
+        if api_response == "Quota exceeded":
+            raise RuntimeError(
+                f"The daily quota of the file {filename} is exceeded and it "
+                f"can't be downloaded. This is a limitation of Google Drive "
+                f"and can only be overcome by trying again later."
+            )
+
+        _save_response_content(content, fpath)


 def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
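
The public entry point is unchanged by this diff; a hypothetical call (the file id and filename below are placeholders) that now transparently confirms the virus-scan warning for large files:

```python
from torchvision.datasets.utils import download_file_from_google_drive

# Placeholder id; md5=None skips checksum verification after download.
download_file_from_google_drive(
    file_id="1A2B3C4D5E6F7G8H9I0J",
    root="~/datasets",
    filename="archive.tar.gz",
    md5=None,
)
```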