Unverified Commit 7d415473 authored by Philip Meier, committed by GitHub

add custom user agent for download_url (#3498)

* add custom user agent for download_url

* fix progress bar

* lint

* [test] use repo instead of nightly for download tests
parent 89edfaaa
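
The core of the change: torchvision's dataset downloads now go through urllib requests that carry a custom User-Agent header ("pytorch/vision") instead of urllib's default agent string, which some download hosts appear to reject. A minimal sketch of the idea, separate from the helpers added in the diffs below (the fetch name and example usage are illustrative, not part of the commit):

import urllib.request

USER_AGENT = "pytorch/vision"  # value introduced by this commit

def fetch(url: str, timeout: float = 5.0) -> bytes:
    # Attach the custom User-Agent so the server sees "pytorch/vision"
    # rather than urllib's default "Python-urllib/x.y" string.
    request = urllib.request.Request(url, headers={"User-Agent": USER_AGENT})
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return response.read()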
@@ -26,10 +26,11 @@ jobs:
       - name: Checkout repository
         uses: actions/checkout@v2
 
-      - name: Install PyTorch from the nightlies
-        run: |
-          pip install numpy
-          pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
+      - name: Install torch nightly build
+        run: pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
+
+      - name: Install torchvision
+        run: pip install -e .
 
       - name: Install all optional dataset requirements
         run: pip install scipy pandas pycocotools lmdb requests
@@ -14,7 +14,7 @@ import warnings
 import pytest
 
 from torchvision import datasets
-from torchvision.datasets.utils import download_url, check_integrity, download_file_from_google_drive
+from torchvision.datasets.utils import download_url, check_integrity, download_file_from_google_drive, USER_AGENT
 from common_utils import get_tmp_dir
 from fakedata_generation import places365_root
@@ -150,7 +150,7 @@ def assert_server_response_ok():
 
 
 def assert_url_is_accessible(url, timeout=5.0):
-    request = Request(url, headers=dict(method="HEAD"))
+    request = Request(url, headers={"method": "HEAD", "User-Agent": USER_AGENT})
     with assert_server_response_ok():
         urlopen(request, timeout=timeout)
@@ -160,7 +160,8 @@ def assert_file_downloads_correctly(url, md5, timeout=5.0):
         file = path.join(root, path.basename(url))
         with assert_server_response_ok():
             with open(file, "wb") as fh:
-                response = urlopen(url, timeout=timeout)
+                request = Request(url, headers={"User-Agent": USER_AGENT})
+                response = urlopen(request, timeout=timeout)
                 fh.write(response.read())
 
         assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
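
The tests above only attach the header to outgoing requests; they do not verify that a server actually receives it. A self-contained way to check that is to stand up a local HTTP server and record the User-Agent it sees. This is a sketch, not part of the commit; it assumes the private `_urlretrieve` helper added in the diff below is importable from torchvision.datasets.utils:

import os
import tempfile
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer

from torchvision.datasets.utils import USER_AGENT, _urlretrieve

seen_agents = []

class RecordingHandler(BaseHTTPRequestHandler):
    def do_GET(self):
        # Record the User-Agent header the client actually sent.
        seen_agents.append(self.headers.get("User-Agent"))
        body = b"payload"
        self.send_response(200)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, *args):
        pass  # keep the output quiet

server = HTTPServer(("127.0.0.1", 0), RecordingHandler)
threading.Thread(target=server.serve_forever, daemon=True).start()

url = "http://127.0.0.1:{}/file.bin".format(server.server_port)
destination = os.path.join(tempfile.mkdtemp(), "file.bin")
_urlretrieve(url, destination)
server.shutdown()

assert seen_agents == [USER_AGENT], seen_agents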
@@ -7,11 +7,28 @@ import tarfile
 from typing import Any, Callable, List, Iterable, Optional, TypeVar
 from urllib.parse import urlparse
 import zipfile
+import urllib
+import urllib.request
+import urllib.error
 
 import torch
 from torch.utils.model_zoo import tqdm
 
 
+USER_AGENT = "pytorch/vision"
+
+
+def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
+    with open(filename, "wb") as fh:
+        with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
+            with tqdm(total=response.length) as pbar:
+                for chunk in iter(lambda: response.read(chunk_size), ""):
+                    if not chunk:
+                        break
+                    pbar.update(chunk_size)
+                    fh.write(chunk)
+
+
 def gen_bar_updater() -> Callable[[int, int, int], None]:
     pbar = tqdm(total=None)
@@ -83,8 +100,6 @@ def download_url(
         md5 (str, optional): MD5 checksum of the download. If None, do not check
         max_redirect_hops (int, optional): Maximum number of redirect hops allowed
     """
-    import urllib
-
     root = os.path.expanduser(root)
     if not filename:
         filename = os.path.basename(url)
@@ -108,19 +123,13 @@ def download_url(
     # download the file
     try:
         print('Downloading ' + url + ' to ' + fpath)
-        urllib.request.urlretrieve(
-            url, fpath,
-            reporthook=gen_bar_updater()
-        )
+        _urlretrieve(url, fpath)
     except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]
         if url[:5] == 'https':
             url = url.replace('https:', 'http:')
             print('Failed download. Trying https -> http instead.'
                   ' Downloading ' + url + ' to ' + fpath)
-            urllib.request.urlretrieve(
-                url, fpath,
-                reporthook=gen_bar_updater()
-            )
+            _urlretrieve(url, fpath)
         else:
             raise e
     # check integrity of downloaded file
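
With the helper in place, download_url and every dataset that relies on it pick up the custom agent without call-site changes. A usage sketch, assuming a torchvision build containing this commit (the URL, directory, and filename are placeholders):

from torchvision.datasets.utils import download_url

# Goes through _urlretrieve, which attaches the "pytorch/vision" User-Agent
# and drives the tqdm progress bar from the response's Content-Length.
download_url(
    "https://example.com/dataset.tar.gz",  # placeholder URL
    root="./data",
    filename="dataset.tar.gz",
    md5=None,  # no checksum verification in this sketch
)

One detail worth noting in _urlretrieve: iter(lambda: response.read(chunk_size), "") uses a str sentinel while response.read returns bytes, so the explicit "if not chunk: break" is what terminates the loop once the response is exhausted, and the progress bar advances by chunk_size rather than the actual size of the final, possibly shorter, chunk.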