Commit 15d897ff authored by Julien Chaumond
Browse files

[http] customizable requests user-agent

parent f25e9b6f
......@@ -23,6 +23,7 @@ from botocore.exceptions import ClientError
import requests
from tqdm.auto import tqdm
from contextlib import contextmanager
from . import __version__
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
......@@ -173,7 +174,7 @@ def filename_to_url(filename, cache_dir=None):
return url, etag
def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False):
def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None):
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
......@@ -183,6 +184,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
force_download: if True, re-download the file even if it's already cached in the cache dir.
resume_download: if True, resume the download if an incompletely received file is found.
user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
"""
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
......@@ -195,7 +197,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
# URL, so get it from the cache (downloading if necessary)
return get_from_cache(url_or_filename, cache_dir=cache_dir,
force_download=force_download, proxies=proxies,
resume_download=resume_download)
resume_download=resume_download, user_agent=user_agent)
elif os.path.exists(url_or_filename):
# File, and it exists.
return url_or_filename
......@@ -256,8 +258,19 @@ def s3_get(url, temp_file, proxies=None):
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
def http_get(url, temp_file, proxies=None, resume_size=0):
headers={'Range':'bytes=%d-'%(resume_size,)} if resume_size > 0 else None
def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
if isinstance(user_agent, dict):
ua += "; " + "; ".join(
"{}/{}".format(k, v) for k, v in user_agent.items()
)
elif isinstance(user_agent, six.string_types):
ua += "; "+ user_agent
headers = {
"user-agent": ua
}
if resume_size > 0:
headers['Range'] = 'bytes=%d-' % (resume_size,)
response = requests.get(url, stream=True, proxies=proxies, headers=headers)
if response.status_code == 416: # Range not satisfiable
return
......@@ -271,7 +284,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0):
progress.close()
def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False):
def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None):
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
......@@ -342,7 +355,7 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
s3_get(url, temp_file, proxies=proxies)
else:
http_get(url, temp_file, proxies=proxies, resume_size=resume_size)
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
# we are copying the file before closing it, so flush to avoid truncation
temp_file.flush()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment