"vscode:/vscode.git/clone" did not exist on "8cfd4afa92a1cd5f1e2f4c49640c5fc572fb50f1"
Unverified commit 8d25de7b, authored by Philip Meier and committed by GitHub
Browse files

replace requests with urllib (#4973)

parent 999ef255
......@@ -183,7 +183,7 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
return files
def _quota_exceeded(first_chunk: bytes) -> bool: # type: ignore[name-defined]
def _quota_exceeded(first_chunk: bytes) -> bool:
try:
return "Google Drive - Quota exceeded" in first_chunk.decode()
except UnicodeDecodeError:
......@@ -199,38 +199,28 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
filename (str, optional): Name to save the file under. If None, use the id of the file.
md5 (str, optional): MD5 checksum of the download. If None, do not check
"""
# Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
import requests
url = "https://docs.google.com/uc?export=download"
url = f"https://drive.google.com/uc?export=download&id={file_id}"
root = os.path.expanduser(root)
if not filename:
filename = file_id
fpath = os.path.join(root, filename)
os.makedirs(root, exist_ok=True)
if os.path.isfile(fpath) and check_integrity(fpath, md5):
print("Using downloaded and verified file: " + fpath)
else:
session = requests.Session()
response = session.get(url, params={"id": file_id}, stream=True)
token = _get_confirm_token(response)
return
if token:
params = {"id": file_id, "confirm": token}
response = session.get(url, params=params, stream=True)
os.makedirs(root, exist_ok=True)
with urllib.request.urlopen(url) as response:
# Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent
# with their own API, refer https://github.com/pytorch/vision/issues/2992#issuecomment-730614517.
# Should this be fixed at some place in future, one could refactor the following to no longer rely on decoding
# the first_chunk of the payload
response_content_generator = response.iter_content(32768)
content = iter(lambda: response.read(32768), b"")
first_chunk = None
while not first_chunk: # filter out keep-alive new chunks
first_chunk = next(response_content_generator)
first_chunk = next(content)
if _quota_exceeded(first_chunk):
msg = (
......@@ -240,8 +230,7 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
)
raise RuntimeError(msg)
_save_response_content(itertools.chain((first_chunk,), response_content_generator), fpath)
response.close()
_save_response_content(itertools.chain((first_chunk,), content), fpath)
def _get_confirm_token(response: "requests.models.Response") -> Optional[str]: # type: ignore[name-defined]
......
Markdown is supported
Attaching a file: 0%. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment