"tests/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "2708fac6c5ec017d384a4f41fa74d91c9e79b47d"
Unverified Commit 96f2c0d4 authored by Philip Meier, committed by GitHub

support confirming no virus scan on GDrive download (#5645)



* support confirming no virus scan on GDrive download

* put gen_bar_updater back

* Update torchvision/datasets/utils.py

Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent b7c59a08
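For context, the public entry point touched by this change is download_file_from_google_drive. A minimal usage sketch follows; the file id, directory, and filename are placeholders for illustration, not values from this commit:

from torchvision.datasets.utils import download_file_from_google_drive

# Placeholder id and paths, for illustration only.
download_file_from_google_drive(
    file_id="0B0AbCdEfGhIjKlMn",  # hypothetical Drive file id
    root="~/data",                # expanded via os.path.expanduser
    filename="archive.tar.gz",
    md5=None,                     # pass an MD5 string to verify the download
)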
@@ -11,6 +11,7 @@ import tarfile
 import urllib
 import urllib.error
 import urllib.request
+import warnings
 import zipfile
 from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
 from urllib.parse import urlparse
@@ -24,22 +25,31 @@ from .._internally_replaced_utils import (
     _is_remote_location_available,
 )

 USER_AGENT = "pytorch/vision"


-def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
-    with open(filename, "wb") as fh:
-        with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
-            with tqdm(total=response.length) as pbar:
-                for chunk in iter(lambda: response.read(chunk_size), ""):
-                    if not chunk:
-                        break
-                    pbar.update(chunk_size)
-                    fh.write(chunk)
+def _save_response_content(
+    content: Iterator[bytes],
+    destination: str,
+    length: Optional[int] = None,
+) -> None:
+    with open(destination, "wb") as fh, tqdm(total=length) as pbar:
+        for chunk in content:
+            # filter out keep-alive new chunks
+            if not chunk:
+                continue
+
+            fh.write(chunk)
+            pbar.update(len(chunk))
+
+
+def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
+    with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
+        _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)


 def gen_bar_updater() -> Callable[[int, int, int], None]:
+    warnings.warn("The function `gen_bar_update` is deprecated since 0.13 and will be removed in 0.15.")
     pbar = tqdm(total=None)

     def bar_update(count, block_size, total_size):
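Aside: the refactored _urlretrieve leans on two-argument iter(), which keeps calling response.read(chunk_size) until it returns the b"" sentinel at EOF (the old code compared against the str sentinel "", which a bytes chunk can never equal). A standalone sketch of that pattern, assuming any reachable URL:

import urllib.request

def read_in_chunks(url: str, chunk_size: int = 1024 * 32):
    # Two-argument iter() calls the lambda repeatedly and stops the
    # moment it returns the sentinel b"" (end of stream).
    with urllib.request.urlopen(url) as response:
        yield from iter(lambda: response.read(chunk_size), b"")

# Placeholder URL; sums the payload size chunk by chunk.
total = sum(len(chunk) for chunk in read_in_chunks("https://example.com"))
print(total)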
@@ -184,11 +194,20 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
     return files


-def _quota_exceeded(first_chunk: bytes) -> bool:
+def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[bytes, Iterator[bytes]]:
+    content = response.iter_content(chunk_size)
+    first_chunk = None
+    # filter out keep-alive new chunks
+    while not first_chunk:
+        first_chunk = next(content)
+    content = itertools.chain([first_chunk], content)
+
     try:
-        return "Google Drive - Quota exceeded" in first_chunk.decode()
+        match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
+        api_response = match["api_response"] if match is not None else None
     except UnicodeDecodeError:
-        return False
+        api_response = None
+    return api_response, content


 def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
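Aside: rather than trusting status codes (which Drive does not use consistently), _extract_gdrive_api_response reads Google Drive's verdict out of the interstitial page's <title>. A self-contained sketch of the same regex; the sample payloads are made up for illustration:

import re
from typing import Optional

def parse_api_response(first_chunk: bytes) -> Optional[str]:
    # HTML interstitials carry the verdict in their <title>; raw file
    # content usually fails to decode and yields None instead.
    try:
        match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
        return match["api_response"] if match is not None else None
    except UnicodeDecodeError:
        return None

assert parse_api_response(b"<title>Google Drive - Quota exceeded</title>") == "Quota exceeded"
assert parse_api_response(b"\x89PNG\r\n\x1a\n") is None  # binary payload, not HTML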
@@ -202,8 +221,6 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
     """
     # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
-    url = "https://docs.google.com/uc?export=download"
-
     root = os.path.expanduser(root)
     if not filename:
         filename = file_id
@@ -211,61 +228,34 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
     os.makedirs(root, exist_ok=True)

-    if os.path.isfile(fpath) and check_integrity(fpath, md5):
-        print("Using downloaded and verified file: " + fpath)
-    else:
-        session = requests.Session()
-
-        response = session.get(url, params={"id": file_id}, stream=True)
-        token = _get_confirm_token(response)
-
-        if token:
-            params = {"id": file_id, "confirm": token}
-            response = session.get(url, params=params, stream=True)
-
-        # Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent
-        # with their own API, refer https://github.com/pytorch/vision/issues/2992#issuecomment-730614517.
-        # Should this be fixed at some place in future, one could refactor the following to no longer rely on decoding
-        # the first_chunk of the payload
-        response_content_generator = response.iter_content(32768)
-        first_chunk = None
-        while not first_chunk:  # filter out keep-alive new chunks
-            first_chunk = next(response_content_generator)
-
-        if _quota_exceeded(first_chunk):
-            msg = (
-                f"The daily quota of the file {filename} is exceeded and it "
-                f"can't be downloaded. This is a limitation of Google Drive "
-                f"and can only be overcome by trying again later."
-            )
-            raise RuntimeError(msg)
-
-        _save_response_content(itertools.chain((first_chunk,), response_content_generator), fpath)
-        response.close()
-
-
-def _get_confirm_token(response: requests.models.Response) -> Optional[str]:
-    for key, value in response.cookies.items():
-        if key.startswith("download_warning"):
-            return value
-
-    return None
-
-
-def _save_response_content(
-    response_gen: Iterator[bytes],
-    destination: str,
-) -> None:
-    with open(destination, "wb") as f:
-        pbar = tqdm(total=None)
-        progress = 0
-        for chunk in response_gen:
-            if chunk:  # filter out keep-alive new chunks
-                f.write(chunk)
-                progress += len(chunk)
-
-                pbar.update(progress - pbar.n)
-        pbar.close()
+    if check_integrity(fpath, md5):
+        print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
+        return
+
+    url = "https://drive.google.com/uc"
+    params = dict(id=file_id, export="download")
+    with requests.Session() as session:
+        response = session.get(url, params=params, stream=True)
+
+        for key, value in response.cookies.items():
+            if key.startswith("download_warning"):
+                token = value
+                break
+        else:
+            api_response, content = _extract_gdrive_api_response(response)
+            token = "t" if api_response == "Virus scan warning" else None
+
+        if token is not None:
+            response = session.get(url, params=dict(params, confirm=token), stream=True)
+            api_response, content = _extract_gdrive_api_response(response)
+
+        if api_response == "Quota exceeded":
+            raise RuntimeError(
+                f"The daily quota of the file {filename} is exceeded and it "
+                f"can't be downloaded. This is a limitation of Google Drive "
+                f"and can only be overcome by trying again later."
+            )
+
+        _save_response_content(content, fpath)


 def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
...
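Aside: condensed, the new download flow is: request once, prefer a download_warning cookie as the confirm token, otherwise parse the interstitial title and send confirm=t on a virus scan warning. A minimal sketch of that token decision, assuming the interstitial behaviour shown above; gdrive_confirm_token is a hypothetical helper, not part of this commit:

from typing import Optional
import requests

def gdrive_confirm_token(response: requests.Response, first_chunk: bytes) -> Optional[str]:
    # A "download_warning" cookie takes precedence; otherwise fall back
    # to spotting the virus-scan interstitial in the first chunk.
    for key, value in response.cookies.items():
        if key.startswith("download_warning"):
            return value
    return "t" if b"Virus scan warning" in first_chunk else None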