Unverified commit ab60e538, authored by ORippler and committed by GitHub
Browse files

Fix download from google drive which was downloading empty files in some cases (#4109)

parent 95966689
......@@ -5,15 +5,15 @@ import hashlib
import gzip
import re
import tarfile
from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple
from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
from urllib.parse import urlparse
import zipfile
import lzma
import contextlib
import urllib
import urllib.request
import urllib.error
import pathlib
import itertools
import torch
from torch.utils.model_zoo import tqdm
......@@ -184,11 +184,10 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
return files
def _quota_exceeded(response: "requests.models.Response") -> bool: # type: ignore[name-defined]
def _quota_exceeded(first_chunk: bytes) -> bool: # type: ignore[name-defined]
try:
start = next(response.iter_content(chunk_size=128, decode_unicode=True))
return isinstance(start, str) and "Google Drive - Quota exceeded" in start
except StopIteration:
return "Google Drive - Quota exceeded" in first_chunk.decode()
except UnicodeDecodeError:
return False
......@@ -224,7 +223,16 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
params = {'id': file_id, 'confirm': token}
response = session.get(url, params=params, stream=True)
if _quota_exceeded(response):
# Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent
# with their own API, refer https://github.com/pytorch/vision/issues/2992#issuecomment-730614517.
# Should this be fixed at some place in future, one could refactor the following to no longer rely on decoding
# the first_chunk of the payload
response_content_generator = response.iter_content(32768)
first_chunk = None
while not first_chunk: # filter out keep-alive new chunks
first_chunk = next(response_content_generator)
if _quota_exceeded(first_chunk):
msg = (
f"The daily quota of the file {filename} is exceeded and it "
f"can't be downloaded. This is a limitation of Google Drive "
......@@ -232,7 +240,8 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
)
raise RuntimeError(msg)
_save_response_content(response, fpath)
_save_response_content(itertools.chain((first_chunk, ), response_content_generator), fpath)
response.close()
def _get_confirm_token(response: "requests.models.Response") -> Optional[str]: # type: ignore[name-defined]
......@@ -244,12 +253,13 @@ def _get_confirm_token(response: "requests.models.Response") -> Optional[str]:
def _save_response_content(
response: "requests.models.Response", destination: str, chunk_size: int = 32768, # type: ignore[name-defined]
response_gen: Iterator[bytes], destination: str, # type: ignore[name-defined]
) -> None:
with open(destination, "wb") as f:
pbar = tqdm(total=None)
progress = 0
for chunk in response.iter_content(chunk_size):
for chunk in response_gen:
if chunk: # filter out keep-alive new chunks
f.write(chunk)
progress += len(chunk)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment