Commit 15d897ff authored by Julien Chaumond

[http] customizable requests user-agent

parent f25e9b6f
@@ -23,6 +23,7 @@ from botocore.exceptions import ClientError
 import requests
 from tqdm.auto import tqdm
 from contextlib import contextmanager
+from . import __version__

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
@@ -173,7 +174,7 @@ def filename_to_url(filename, cache_dir=None):
     return url, etag


-def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False):
+def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None):
     """
     Given something that might be a URL (or might be a local path),
     determine which. If it's a URL, download the file and cache it, and
@@ -183,6 +184,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
         cache_dir: specify a cache directory to save the file to (override the default cache dir).
         force_download: if True, re-download the file even if it's already cached in the cache dir.
         resume_download: if True, resume the download if an incompletely received file is found.
+        user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
     """
     if cache_dir is None:
         cache_dir = TRANSFORMERS_CACHE
@@ -195,7 +197,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
         # URL, so get it from the cache (downloading if necessary)
         return get_from_cache(url_or_filename, cache_dir=cache_dir,
                               force_download=force_download, proxies=proxies,
-                              resume_download=resume_download)
+                              resume_download=resume_download, user_agent=user_agent)
     elif os.path.exists(url_or_filename):
         # File, and it exists.
         return url_or_filename
@@ -256,8 +258,19 @@ def s3_get(url, temp_file, proxies=None):
     s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)


-def http_get(url, temp_file, proxies=None, resume_size=0):
-    headers={'Range':'bytes=%d-'%(resume_size,)} if resume_size > 0 else None
+def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
+    ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
+    if isinstance(user_agent, dict):
+        ua += "; " + "; ".join(
+            "{}/{}".format(k, v) for k, v in user_agent.items()
+        )
+    elif isinstance(user_agent, six.string_types):
+        ua += "; " + user_agent
+    headers = {
+        "user-agent": ua
+    }
+    if resume_size > 0:
+        headers['Range'] = 'bytes=%d-' % (resume_size,)
     response = requests.get(url, stream=True, proxies=proxies, headers=headers)
     if response.status_code == 416:  # Range not satisfiable
         return
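Note: the new lines in http_get above compose a base identifier from the library and Python versions, then append any caller-supplied fields. A minimal standalone sketch of the same construction, runnable outside the library (the __version__ value and the sample dict are illustrative, not from the commit):

# Standalone sketch of the user-agent construction shown in the hunk above.
# Assumes six is installed; __version__ is a placeholder for transformers.__version__.
import sys
import six

__version__ = "2.3.0"  # hypothetical version, for illustration only

def build_user_agent(user_agent=None):
    # Base UA: library version plus interpreter version.
    ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
    if isinstance(user_agent, dict):
        # Render dict entries as "key/value" pairs, joined with "; ".
        ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
    elif isinstance(user_agent, six.string_types):
        # A plain string is appended verbatim.
        ua += "; " + user_agent
    return ua

print(build_user_agent({"pipeline": "ner"}))
# -> e.g. "transformers/2.3.0; python/3.7.4; pipeline/ner"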
@@ -271,7 +284,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0):
     progress.close()


-def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False):
+def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None):
     """
     Given a URL, look for the corresponding dataset in the local cache.
     If it's not there, download it. Then return the path to the cached file.
@@ -342,7 +355,7 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
             logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
             s3_get(url, temp_file, proxies=proxies)
         else:
-            http_get(url, temp_file, proxies=proxies, resume_size=resume_size)
+            http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
     # we are copying the file before closing it, so flush to avoid truncation
     temp_file.flush()
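With user_agent threaded through cached_path -> get_from_cache -> http_get, callers can now tag their HTTP downloads. A hypothetical usage sketch, assuming this diff is against transformers' file_utils module (the URL and the user_agent values are illustrative, not from the commit):

from transformers.file_utils import cached_path

# Appends "tool/my-script; version/0.1" to the default
# "transformers/x.y.z; python/a.b.c" user-agent on the download request.
path = cached_path(
    "https://example.com/models/config.json",  # illustrative URL
    user_agent={"tool": "my-script", "version": "0.1"},
)

# A plain string works too:
path = cached_path(
    "https://example.com/models/config.json",
    user_agent="my-experiment/1.0",
)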