Unverified Commit 7d415473 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add custom user agent for download_url (#3498)

* add custom user agent for download_url

* fix progress bar

* lint

* [test] use repo instead of nightly for download tests
parent 89edfaaa
......@@ -26,10 +26,11 @@ jobs:
- name: Checkout repository
uses: actions/checkout@v2
- name: Install PyTorch from the nightlies
run: |
pip install numpy
pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
- name: Install torch nightly build
run: pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
- name: Install torchvision
run: pip install -e .
- name: Install all optional dataset requirements
run: pip install scipy pandas pycocotools lmdb requests
......
......@@ -14,7 +14,7 @@ import warnings
import pytest
from torchvision import datasets
from torchvision.datasets.utils import download_url, check_integrity, download_file_from_google_drive
from torchvision.datasets.utils import download_url, check_integrity, download_file_from_google_drive, USER_AGENT
from common_utils import get_tmp_dir
from fakedata_generation import places365_root
......@@ -150,7 +150,7 @@ def assert_server_response_ok():
def assert_url_is_accessible(url, timeout=5.0):
request = Request(url, headers=dict(method="HEAD"))
request = Request(url, headers={"method": "HEAD", "User-Agent": USER_AGENT})
with assert_server_response_ok():
urlopen(request, timeout=timeout)
......@@ -160,7 +160,8 @@ def assert_file_downloads_correctly(url, md5, timeout=5.0):
file = path.join(root, path.basename(url))
with assert_server_response_ok():
with open(file, "wb") as fh:
response = urlopen(url, timeout=timeout)
request = Request(url, headers={"User-Agent": USER_AGENT})
response = urlopen(request, timeout=timeout)
fh.write(response.read())
assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
......
......@@ -7,11 +7,28 @@ import tarfile
from typing import Any, Callable, List, Iterable, Optional, TypeVar
from urllib.parse import urlparse
import zipfile
import urllib
import urllib.request
import urllib.error
import torch
from torch.utils.model_zoo import tqdm
USER_AGENT = "pytorch/vision"
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
with open(filename, "wb") as fh:
with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
with tqdm(total=response.length) as pbar:
for chunk in iter(lambda: response.read(chunk_size), ""):
if not chunk:
break
pbar.update(chunk_size)
fh.write(chunk)
def gen_bar_updater() -> Callable[[int, int, int], None]:
pbar = tqdm(total=None)
......@@ -83,8 +100,6 @@ def download_url(
md5 (str, optional): MD5 checksum of the download. If None, do not check
max_redirect_hops (int, optional): Maximum number of redirect hops allowed
"""
import urllib
root = os.path.expanduser(root)
if not filename:
filename = os.path.basename(url)
......@@ -108,19 +123,13 @@ def download_url(
# download the file
try:
print('Downloading ' + url + ' to ' + fpath)
urllib.request.urlretrieve(
url, fpath,
reporthook=gen_bar_updater()
)
_urlretrieve(url, fpath)
except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
if url[:5] == 'https':
url = url.replace('https:', 'http:')
print('Failed download. Trying https -> http instead.'
' Downloading ' + url + ' to ' + fpath)
urllib.request.urlretrieve(
url, fpath,
reporthook=gen_bar_updater()
)
_urlretrieve(url, fpath)
else:
raise e
# check integrity of downloaded file
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment