Unverified Commit 0985533e authored by Sofya Lipnitskaya's avatar Sofya Lipnitskaya Committed by GitHub
Browse files

Make download_url() follow redirects (#3235) (#3236)



* Make download_url() follow redirects

Fix bug related to the incorrect processing of redirects.
Follow the redirect chain until the destination is reached or
the number of redirects exceeds the max allowed value (by default 10).

* Parametrize value of max allowed redirect number

Make max number of hops a function argument and assign its default value to 10

* Propagate the max number of hops to download_url()

Add the maximum number of redirect hops parameter to download_url()

* check file existence before redirect

* remove print

* remove recursion

* add tests

* Reducing max_redirect_hops
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Vasilis Vryniotis <vvryniotis@fb.com>
parent 3b19d6fc
...@@ -36,6 +36,18 @@ class Tester(unittest.TestCase): ...@@ -36,6 +36,18 @@ class Tester(unittest.TestCase):
self.assertTrue(utils.check_integrity(existing_fpath)) self.assertTrue(utils.check_integrity(existing_fpath))
self.assertFalse(utils.check_integrity(nonexisting_fpath)) self.assertFalse(utils.check_integrity(nonexisting_fpath))
def test_get_redirect_url(self):
    # The Caltech URL is known to redirect to a Google Drive view page;
    # the helper should resolve the full chain to that final URL.
    initial_url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"
    resolved_url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
    self.assertEqual(utils._get_redirect_url(initial_url), resolved_url)
def test_get_redirect_url_max_hops_exceeded(self):
    # With a zero hop budget, a URL that redirects must raise rather
    # than silently return an unresolved location.
    redirecting_url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"
    with self.assertRaises(RecursionError):
        utils._get_redirect_url(redirecting_url, max_hops=0)
def test_download_url(self): def test_download_url(self):
with get_tmp_dir() as temp_dir: with get_tmp_dir() as temp_dir:
url = "http://github.com/pytorch/vision/archive/master.zip" url = "http://github.com/pytorch/vision/archive/master.zip"
......
...@@ -42,7 +42,23 @@ def check_integrity(fpath: str, md5: Optional[str] = None) -> bool: ...@@ -42,7 +42,23 @@ def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
return check_md5(fpath, md5) return check_md5(fpath, md5)
def _get_redirect_url(url: str, max_hops: int = 10) -> str:
    """Follow the HTTP redirect chain of ``url`` and return the final URL.

    Args:
        url (str): URL to resolve.
        max_hops (int, optional): Maximum number of redirect hops allowed. Default is 10.

    Returns:
        str: The destination URL once it no longer redirects.

    Raises:
        RecursionError: If the URL still redirects after ``max_hops + 1`` requests.
    """
    import requests

    for _ in range(max_hops + 1):
        response = requests.get(url)

        # requests reports the URL it ended up at; when it matches what we
        # asked for (or is missing), the chain is fully resolved.
        if response.url == url or response.url is None:
            return url

        url = response.url
    else:
        # Loop exhausted without the URL stabilizing.
        # NOTE: fixed stray ')' that was previously embedded in the message.
        raise RecursionError(f"Too many redirects: {max_hops + 1}")
def download_url(
    url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
) -> None:
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check
        max_redirect_hops (int, optional): Maximum number of redirect hops allowed
    """
    import urllib

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    # check if file is already present locally
    if check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    # expand redirect chain if needed; done only after the local-file check so
    # a cached file never triggers a network round-trip
    url = _get_redirect_url(url, max_hops=max_redirect_hops)

    # download the file
    try:
        print('Downloading ' + url + ' to ' + fpath)
        urllib.request.urlretrieve(
            url, fpath,
            reporthook=gen_bar_updater()
        )
    except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]
        # best-effort fallback: some hosts only serve plain HTTP
        if url[:5] == 'https':
            url = url.replace('https:', 'http:')
            print('Failed download. Trying https -> http instead.'
                  ' Downloading ' + url + ' to ' + fpath)
            urllib.request.urlretrieve(
                url, fpath,
                reporthook=gen_bar_updater()
            )
        else:
            raise e

    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")
def list_dir(root: str, prefix: bool = False) -> List[str]: def list_dir(root: str, prefix: bool = False) -> List[str]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment