Unverified Commit ab60e538 authored by ORippler, committed by GitHub

Fix download from google drive which was downloading empty files in some cases (#4109)

parent 95966689
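
In outline, the patch below peeks at the first non-empty chunk of the streamed response, runs the quota check on those raw bytes, and then stitches the chunk back onto the stream before saving. The following standalone sketch (not part of the commit; the stream contents are made-up stand-ins for response.iter_content(...)) illustrates the hazard being worked around: bytes pulled off a one-shot streaming iterator for inspection are lost unless they are explicitly re-attached.

    import itertools

    # Stand-in for response.iter_content(...): a one-shot stream of byte chunks.
    stream = iter([b"payload"])
    peeked = next(stream)        # inspecting the stream consumes this chunk...
    assert list(stream) == []    # ...so a later save loop would write an empty file

    # The approach the patch takes: chain the peeked chunk back in front of the rest.
    stream = iter([b"payload"])
    peeked = next(stream)
    restored = itertools.chain((peeked,), stream)
    assert list(restored) == [b"payload"]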
@@ -5,15 +5,15 @@ import hashlib
 import gzip
 import re
 import tarfile
-from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple
+from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
 from urllib.parse import urlparse
 import zipfile
 import lzma
-import contextlib
 import urllib
 import urllib.request
 import urllib.error
 import pathlib
+import itertools
 
 import torch
 from torch.utils.model_zoo import tqdm
@@ -184,11 +184,10 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
     return files
 
 
-def _quota_exceeded(response: "requests.models.Response") -> bool:  # type: ignore[name-defined]
+def _quota_exceeded(first_chunk: bytes) -> bool:  # type: ignore[name-defined]
     try:
-        start = next(response.iter_content(chunk_size=128, decode_unicode=True))
-        return isinstance(start, str) and "Google Drive - Quota exceeded" in start
-    except StopIteration:
+        return "Google Drive - Quota exceeded" in first_chunk.decode()
+    except UnicodeDecodeError:
         return False
 
 
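
A minimal sketch of how the rewritten check behaves (a standalone copy with example inputs, not part of the commit): an HTML quota page decodes as text and matches the marker string, while binary file content that fails UTF-8 decoding is treated as a real payload.

    def _quota_exceeded(first_chunk: bytes) -> bool:
        try:
            return "Google Drive - Quota exceeded" in first_chunk.decode()
        except UnicodeDecodeError:
            return False

    assert _quota_exceeded(b"<title>Google Drive - Quota exceeded</title>")
    assert not _quota_exceeded(b"\x89PNG\r\n\x1a\n\x00\x00")  # binary chunk: decode fails
    assert not _quota_exceeded(b"regular text payload")       # decodes, but no marker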
@@ -224,7 +223,16 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
         params = {'id': file_id, 'confirm': token}
         response = session.get(url, params=params, stream=True)
 
-    if _quota_exceeded(response):
+    # Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent
+    # with their own API, refer https://github.com/pytorch/vision/issues/2992#issuecomment-730614517.
+    # Should this be fixed at some place in future, one could refactor the following to no longer rely on decoding
+    # the first_chunk of the payload
+    response_content_generator = response.iter_content(32768)
+    first_chunk = None
+    while not first_chunk:  # filter out keep-alive new chunks
+        first_chunk = next(response_content_generator)
+
+    if _quota_exceeded(first_chunk):
         msg = (
             f"The daily quota of the file {filename} is exceeded and it "
             f"can't be downloaded. This is a limitation of Google Drive "
@@ -232,7 +240,8 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
         )
         raise RuntimeError(msg)
 
-    _save_response_content(response, fpath)
+    _save_response_content(itertools.chain((first_chunk, ), response_content_generator), fpath)
+    response.close()
 
 
 def _get_confirm_token(response: "requests.models.Response") -> Optional[str]:  # type: ignore[name-defined]
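
The `while not first_chunk` loop above relies on empty byte strings being falsy, so keep-alive chunks are skipped until real payload bytes arrive. A standalone sketch with fabricated chunks (note that a stream carrying no payload at all would let `next()` raise StopIteration, which the patch does not guard against):

    gen = iter([b"", b"", b"data..."])  # leading b"" chunks mimic keep-alive chunks
    first_chunk = None
    while not first_chunk:              # None and b"" are both falsy
        first_chunk = next(gen)
    assert first_chunk == b"data..."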
@@ -244,12 +253,13 @@ def _get_confirm_token(response: "requests.models.Response") -> Optional[str]:
 
 
 def _save_response_content(
-    response: "requests.models.Response", destination: str, chunk_size: int = 32768,  # type: ignore[name-defined]
+    response_gen: Iterator[bytes], destination: str,  # type: ignore[name-defined]
 ) -> None:
     with open(destination, "wb") as f:
         pbar = tqdm(total=None)
         progress = 0
-        for chunk in response.iter_content(chunk_size):
+
+        for chunk in response_gen:
             if chunk:  # filter out keep-alive new chunks
                 f.write(chunk)
                 progress += len(chunk)
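
Because `_save_response_content` now accepts any `Iterator[bytes]` rather than a requests `Response`, it can be driven by plain in-memory chunks. A hypothetical usage sketch (the chunk values and output path are made up, and `_save_response_content` is assumed importable along with its tqdm dependency):

    chunks = iter([b"hello ", b"", b"world"])  # b"" mimics a keep-alive chunk and is skipped
    _save_response_content(chunks, "/tmp/example.bin")  # writes b"hello world"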