import os
import os.path
import hashlib
import errno

from torch.utils.model_zoo import tqdm


def gen_bar_updater():
    """Return a ``reporthook`` for ``urlretrieve`` that drives a tqdm bar."""
    pbar = tqdm(total=None)

    def bar_update(count, block_size, total_size):
        if pbar.total is None and total_size:
            pbar.total = total_size
        progress_bytes = count * block_size
        pbar.update(progress_bytes - pbar.n)

    return bar_update


def calculate_md5(fpath, chunk_size=1024 * 1024):
    md5 = hashlib.md5()
    with open(fpath, 'rb') as f:
        # Read in fixed-size chunks so large files need not fit in memory.
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()


def check_md5(fpath, md5, **kwargs):
    return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath, md5=None):
    if not os.path.isfile(fpath):
        return False
    if md5 is None:
        return True
    return check_md5(fpath, md5)


def makedir_exist_ok(dirpath):
    """
    Python2 support for os.makedirs(.., exist_ok=True)
    """
    try:
        os.makedirs(dirpath)
    except OSError as e:
        if e.errno == errno.EEXIST:
            pass
        else:
            raise


def download_url(url, root, filename=None, md5=None):
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under.
            If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download.
            If None, do not check
    """
    from six.moves import urllib

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    makedir_exist_ok(root)

    # downloads file
    if os.path.isfile(fpath) and check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
    else:
        try:
            print('Downloading ' + url + ' to ' + fpath)
            urllib.request.urlretrieve(
                url, fpath,
                reporthook=gen_bar_updater()
            )
        except OSError:
            if url[:5] == 'https':
                # Some servers reject HTTPS requests; retry over plain HTTP.
                url = url.replace('https:', 'http:')
                print('Failed download. Trying https -> http instead.'
                      ' Downloading ' + url + ' to ' + fpath)
                urllib.request.urlretrieve(
                    url, fpath,
                    reporthook=gen_bar_updater()
                )
            else:
                # Re-raise for non-HTTPS URLs instead of failing silently.
                raise


def list_dir(root, prefix=False):
    """List all directories at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result,
            otherwise only returns the name of the directories found
    """
    root = os.path.expanduser(root)
    directories = list(
        filter(
            lambda p: os.path.isdir(os.path.join(root, p)),
            os.listdir(root)
        )
    )

    if prefix is True:
        directories = [os.path.join(root, d) for d in directories]

    return directories


def list_files(root, suffix, prefix=False):
    """List all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose files need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or
            ('.jpg', '.png'). It uses the Python "str.endswith" method and
            is passed directly
        prefix (bool, optional): If true, prepends the path to each result,
            otherwise only returns the name of the files found
    """
    root = os.path.expanduser(root)
    files = list(
        filter(
            lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
            os.listdir(root)
        )
    )

    if prefix is True:
        files = [os.path.join(root, f) for f in files]

    return files
def download_file_from_google_drive(file_id, root, filename=None, md5=None):
    """Download a Google Drive file and place it in root.

    Args:
        file_id (str): id of file to be downloaded
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under.
            If None, use the id of the file.
        md5 (str, optional): MD5 checksum of the download.
            If None, do not check
    """
    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
    import requests
    url = "https://docs.google.com/uc?export=download"

    root = os.path.expanduser(root)
    if not filename:
        filename = file_id
    fpath = os.path.join(root, filename)

    makedir_exist_ok(root)

    if os.path.isfile(fpath) and check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
    else:
        session = requests.Session()

        response = session.get(url, params={'id': file_id}, stream=True)
        token = _get_confirm_token(response)

        if token:
            # Large files trigger a "can't scan for viruses" interstitial;
            # resend the request with the confirmation token to get the file.
            params = {'id': file_id, 'confirm': token}
            response = session.get(url, params=params, stream=True)

        _save_response_content(response, fpath)


def _get_confirm_token(response):
    # Google Drive signals the virus-scan interstitial through a cookie
    # whose name starts with 'download_warning'; its value is the token.
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value

    return None


def _save_response_content(response, destination, chunk_size=32768):
    # Stream the response body to disk chunk by chunk, updating a tqdm bar.
    with open(destination, "wb") as f:
        pbar = tqdm(total=None)
        progress = 0
        for chunk in response.iter_content(chunk_size):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
                progress += len(chunk)
                pbar.update(progress - pbar.n)
        pbar.close()
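

# --- Usage sketch -------------------------------------------------------------
# A minimal, illustrative example of how these helpers compose. The URL,
# filename, and checksum below are hypothetical placeholders, not artifacts
# shipped with this module; substitute real values before running.
if __name__ == '__main__':
    example_url = 'https://example.com/data/archive.tar.gz'  # hypothetical
    example_root = '~/datasets'

    # Downloads the file (with a tqdm progress bar) unless a verified copy
    # already exists; pass a real MD5 hex digest to enable checksum checking.
    download_url(example_url, example_root, filename='archive.tar.gz', md5=None)

    # With md5=None, check_integrity only verifies that the file exists.
    fpath = os.path.join(os.path.expanduser(example_root), 'archive.tar.gz')
    print('file present and intact:', check_integrity(fpath))

    # List any .tar.gz archives that ended up under the root directory.
    print(list_files(example_root, '.tar.gz', prefix=True))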