Unverified commit e0407b51 authored by Vincent QB, committed by GitHub

Resume download (#320)

* resume download, validate with md5 or sha256.

* split stream from saving. detect filename.

* validate at end too. check file size again.

* validate now operates on file object.

* expose choices of hash.
parent 3dcf812b
import csv
import errno
import gzip
import hashlib
import logging
import os
@@ -12,7 +11,7 @@ from queue import Queue
import six
import torch
import torchaudio
from six.moves import urllib
from torch.utils.data import Dataset
from torch.utils.model_zoo import tqdm
@@ -53,18 +52,6 @@ def unicode_csv_reader(unicode_csv_data, **kwargs):
        yield line
def gen_bar_updater():
    pbar = tqdm(total=None)

    def bar_update(count, block_size, total_size):
        if pbar.total is None and total_size:
            pbar.total = total_size
        progress_bytes = count * block_size
        pbar.update(progress_bytes - pbar.n)

    return bar_update
def makedir_exist_ok(dirpath):
    """
    Python2 support for os.makedirs(.., exist_ok=True)
@@ -78,41 +65,130 @@ def makedir_exist_ok(dirpath):
            raise
def stream_url(url, start_byte=None, block_size=32 * 1024, progress_bar=True):
    """Stream url by chunk

    Args:
        url (str): Url.
        start_byte (Optional[int]): Start streaming at that point.
        block_size (int): Size of chunks to stream.
        progress_bar (bool): Display a progress bar.
    """

    # If we already have the whole file, there is no need to download it again
    req = urllib.request.Request(url, method="HEAD")
    url_size = int(urllib.request.urlopen(req).info().get("Content-Length", -1))
    if url_size == start_byte:
        return

    req = urllib.request.Request(url)
    if start_byte:
        req.headers["Range"] = "bytes={}-".format(start_byte)

    with urllib.request.urlopen(req) as upointer, tqdm(
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
        total=url_size,
        disable=not progress_bar,
    ) as pbar:

        num_bytes = 0
        while True:
            chunk = upointer.read(block_size)
            if not chunk:
                break
            yield chunk
            num_bytes += len(chunk)
            pbar.update(len(chunk))
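For reference, a minimal sketch (not part of this commit) of using stream_url on its own to fetch only the missing tail of a partially downloaded file; the URL and local path are placeholders, and it assumes stream_url is available in the current module or imported.

import os

url = "http://example.com/corpus.tar.gz"    # placeholder URL
local_path = "corpus.tar.gz"                # placeholder local path

# Resume from however many bytes are already on disk, if any.
start = os.path.getsize(local_path) if os.path.exists(local_path) else None

with open(local_path, "ab" if start else "wb") as fpointer:
    for chunk in stream_url(url, start_byte=start, progress_bar=False):
        fpointer.write(chunk)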
def download_url(
    url,
    download_folder,
    filename=None,
    hash_value=None,
    hash_type="sha256",
    progress_bar=True,
    resume=False,
):
    """Download file to disk.

    Args:
        url (str): Url.
        download_folder (str): Folder to download file.
        filename (str): Name of downloaded file. If None, it is inferred from the url.
        hash_value (str): Hash for url.
        hash_type (str): Hash type, among "sha256" and "md5".
        progress_bar (bool): Display a progress bar.
        resume (bool): Enable resuming download.
    """

    req = urllib.request.Request(url, method="HEAD")
    req_info = urllib.request.urlopen(req).info()

    # Detect filename
    filename = filename or req_info.get_filename() or os.path.basename(url)
    filepath = os.path.join(download_folder, filename)

    if resume and os.path.exists(filepath):
        mode = "ab"
        local_size = os.path.getsize(filepath)
    elif not resume and os.path.exists(filepath):
        raise RuntimeError(
            "{} already exists. Delete the file manually and retry.".format(filepath)
        )
    else:
        mode = "wb"
        local_size = None

    if hash_value and local_size == int(req_info.get("Content-Length", -1)):
        with open(filepath, "rb") as file_obj:
            if validate_file(file_obj, hash_value, hash_type):
                return
        raise RuntimeError(
            "The hash of {} does not match. Delete the file manually and retry.".format(
                filepath
            )
        )

    with open(filepath, mode) as fpointer:
        for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar):
            fpointer.write(chunk)

    with open(filepath, "rb") as file_obj:
        if hash_value and not validate_file(file_obj, hash_value, hash_type):
            raise RuntimeError(
                "The hash of {} does not match. Delete the file manually and retry.".format(
                    filepath
                )
            )
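A usage sketch of the reworked download_url (not part of the diff); the import path, URL, folder, and digest below are placeholder assumptions rather than values taken from the commit.

# Assumes these helpers live in torchaudio.datasets.utils; adjust the import if not.
from torchaudio.datasets.utils import download_url

download_url(
    "http://example.com/corpus.tar.gz",          # placeholder URL
    "./data",                                    # placeholder download_folder
    hash_value="<expected sha256 hex digest>",   # placeholder; pass None to skip validation
    hash_type="sha256",
    resume=True,    # append to an existing partial file instead of failing
)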
def validate_file(file_obj, hash_value, hash_type="sha256"):
    """Validate a given file object with its hash.

    Args:
        file_obj: File object to read from.
        hash_value (str): Hash for url.
        hash_type (str): Hash type, among "sha256" and "md5".
    """
    if hash_type == "sha256":
        hash_func = hashlib.sha256()
    elif hash_type == "md5":
        hash_func = hashlib.md5()
    else:
        raise ValueError("Unsupported hash_type: expected 'sha256' or 'md5'.")

    while True:
        # Read by chunk to avoid filling memory
        chunk = file_obj.read(1024 ** 2)
        if not chunk:
            break
        hash_func.update(chunk)

    return hash_func.hexdigest() == hash_value
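And a short sketch (not from the commit) of calling validate_file directly to re-check a file that is already on disk; the path is a placeholder and the digest shown is just the md5 of empty input, used here as a stand-in.

with open("corpus.tar.gz", "rb") as file_obj:   # placeholder path
    ok = validate_file(file_obj, "d41d8cd98f00b204e9800998ecf8427e", hash_type="md5")

if not ok:
    raise RuntimeError("md5 mismatch: delete the file and download it again.")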
def extract_archive(from_path, to_path=None, overwrite=False):
......