"cacheflow/git@developer.sourcefind.cn:norm/vllm.git" did not exist on "762fd1c3faf17001b943cb68e5d08e7d7fe59119"
Unverified Commit 96f2c0d4 authored by Philip Meier, committed by GitHub

support confirming no virus scan on GDrive download (#5645)



* support confirming no virus scan on GDrive download

* put gen_bar_updater back

* Update torchvision/datasets/utils.py

Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent b7c59a08
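For orientation before the diff: the public helper this commit reworks is download_file_from_google_drive. A minimal usage sketch follows; the file ID and paths are placeholders, not values taken from this commit:

    from torchvision.datasets.utils import download_file_from_google_drive

    # Placeholder file ID; any shareable Google Drive file ID works the same way.
    download_file_from_google_drive(
        file_id="1AbCdEfGhIjKlMnOpQrStUv",  # hypothetical
        root="~/datasets",  # expanded internally via os.path.expanduser
        filename="archive.tar.gz",  # defaults to file_id when omitted
        md5=None,  # pass an MD5 hex digest to verify the download
    )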
@@ -11,6 +11,7 @@ import tarfile
 import urllib
 import urllib.error
 import urllib.request
+import warnings
 import zipfile
 from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
 from urllib.parse import urlparse
@@ -24,22 +25,31 @@ from .._internally_replaced_utils import (
     _is_remote_location_available,
 )


 USER_AGENT = "pytorch/vision"


-def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
-    with open(filename, "wb") as fh:
-        with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
-            with tqdm(total=response.length) as pbar:
-                for chunk in iter(lambda: response.read(chunk_size), ""):
-                    if not chunk:
-                        break
-                    pbar.update(chunk_size)
-                    fh.write(chunk)
+def _save_response_content(
+    content: Iterator[bytes],
+    destination: str,
+    length: Optional[int] = None,
+) -> None:
+    with open(destination, "wb") as fh, tqdm(total=length) as pbar:
+        for chunk in content:
+            # filter out keep-alive new chunks
+            if not chunk:
+                continue
+
+            fh.write(chunk)
+            pbar.update(len(chunk))
+
+
+def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
+    with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
+        _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)


 def gen_bar_updater() -> Callable[[int, int, int], None]:
+    warnings.warn("The function `gen_bar_update` is deprecated since 0.13 and will be removed in 0.15.")
     pbar = tqdm(total=None)

     def bar_update(count, block_size, total_size):
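An aside on the rewritten _urlretrieve above: iter(callable, sentinel) calls the callable until it returns the sentinel. The old code compared the bytes returned by response.read() against the str sentinel "", which can never match, so it needed the manual `if not chunk: break`; the new b"" sentinel ends the loop by itself. A small self-contained sketch of the pattern, using an in-memory stream in place of the HTTP response:

    import io

    stream = io.BytesIO(b"0123456789abcdef")  # stand-in for an HTTP response body
    chunks = list(iter(lambda: stream.read(4), b""))  # stops once read() returns b""
    assert chunks == [b"0123", b"4567", b"89ab", b"cdef"]

    # With the old sentinel "", the comparison b"" == "" is always False, so the
    # loop could only end through the explicit break on an empty chunk.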
@@ -184,11 +194,20 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
     return files


-def _quota_exceeded(first_chunk: bytes) -> bool:
+def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[bytes, Iterator[bytes]]:
+    content = response.iter_content(chunk_size)
+    first_chunk = None
+    # filter out keep-alive new chunks
+    while not first_chunk:
+        first_chunk = next(content)
+    content = itertools.chain([first_chunk], content)
+
     try:
-        return "Google Drive - Quota exceeded" in first_chunk.decode()
+        match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
+        api_response = match["api_response"] if match is not None else None
     except UnicodeDecodeError:
-        return False
+        api_response = None
+    return api_response, content


 def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
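The new _extract_gdrive_api_response keys off the <title> of Google's interstitial pages. A quick demonstration of that regex; the HTML snippets are fabricated stand-ins, not captured Google Drive responses:

    import re

    PATTERN = "<title>Google Drive - (?P<api_response>.+?)</title>"

    for first_chunk in (
        b"<html><head><title>Google Drive - Virus scan warning</title></head>",  # fabricated
        b"<html><head><title>Google Drive - Quota exceeded</title></head>",  # fabricated
        b"\x89PNG\r\n\x1a\n",  # binary file content: no title, may not decode
    ):
        try:
            match = re.search(PATTERN, first_chunk.decode())
            api_response = match["api_response"] if match is not None else None
        except UnicodeDecodeError:
            api_response = None
        print(api_response)  # Virus scan warning / Quota exceeded / None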
@@ -202,8 +221,6 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
     """
     # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
-    url = "https://docs.google.com/uc?export=download"
-
     root = os.path.expanduser(root)

     if not filename:
         filename = file_id
@@ -211,61 +228,34 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
     os.makedirs(root, exist_ok=True)

     if os.path.isfile(fpath) and check_integrity(fpath, md5):
-        print("Using downloaded and verified file: " + fpath)
-    else:
-        session = requests.Session()
-
-        response = session.get(url, params={"id": file_id}, stream=True)
-        token = _get_confirm_token(response)
-
-        if token:
-            params = {"id": file_id, "confirm": token}
-            response = session.get(url, params=params, stream=True)
-
-        # Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent
-        # with their own API, refer https://github.com/pytorch/vision/issues/2992#issuecomment-730614517.
-        # Should this be fixed at some place in future, one could refactor the following to no longer rely on decoding
-        # the first_chunk of the payload
-        response_content_generator = response.iter_content(32768)
-        first_chunk = None
-        while not first_chunk:  # filter out keep-alive new chunks
-            first_chunk = next(response_content_generator)
-
-        if _quota_exceeded(first_chunk):
-            msg = (
-                f"The daily quota of the file {filename} is exceeded and it "
-                f"can't be downloaded. This is a limitation of Google Drive "
-                f"and can only be overcome by trying again later."
-            )
-            raise RuntimeError(msg)
-
-        _save_response_content(itertools.chain((first_chunk,), response_content_generator), fpath)
-        response.close()
-
-
-def _get_confirm_token(response: requests.models.Response) -> Optional[str]:
-    for key, value in response.cookies.items():
-        if key.startswith("download_warning"):
-            return value
-
-    return None
-
-
-def _save_response_content(
-    response_gen: Iterator[bytes],
-    destination: str,
-) -> None:
-    with open(destination, "wb") as f:
-        pbar = tqdm(total=None)
-        progress = 0
-
-        for chunk in response_gen:
-            if chunk:  # filter out keep-alive new chunks
-                f.write(chunk)
-                progress += len(chunk)
-                pbar.update(progress - pbar.n)
-
-        pbar.close()
+        print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
+        return
+
+    url = "https://drive.google.com/uc"
+    params = dict(id=file_id, export="download")
+    with requests.Session() as session:
+        response = session.get(url, params=params, stream=True)
+
+        for key, value in response.cookies.items():
+            if key.startswith("download_warning"):
+                token = value
+                break
+        else:
+            api_response, content = _extract_gdrive_api_response(response)
+            token = "t" if api_response == "Virus scan warning" else None
+
+        if token is not None:
+            response = session.get(url, params=dict(params, confirm=token), stream=True)
+            api_response, content = _extract_gdrive_api_response(response)
+
+        if api_response == "Quota exceeded":
+            raise RuntimeError(
+                f"The daily quota of the file {filename} is exceeded and it "
+                f"can't be downloaded. This is a limitation of Google Drive "
+                f"and can only be overcome by trying again later."
+            )
+
+        _save_response_content(content, fpath)


 def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
...
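One construct in the new confirmation logic worth calling out is the for/else: the else block runs only when the loop completes without break, which is how the code falls back to parsing the response body when no download_warning cookie is present. A minimal illustration with a hypothetical find_token helper and fabricated cookie dicts:

    from typing import Dict, Optional

    def find_token(cookies: Dict[str, str]) -> Optional[str]:
        # Mirrors the cookie scan in the new download_file_from_google_drive.
        for key, value in cookies.items():
            if key.startswith("download_warning"):
                token = value
                break
        else:  # loop finished without break: no warning cookie was found
            token = None
        return token

    assert find_token({"download_warning_12345": "AbCd"}) == "AbCd"  # fabricated cookie
    assert find_token({"NID": "unrelated"}) is None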