"docs/vscode:/vscode.git/clone" did not exist on "a600b30cc35465326ac11e2b4d26865ea555d08b"
Unverified Commit ff36e6d8, authored by Thomas Wolf, committed by GitHub

Merge pull request #2231 from huggingface/requests_user_agent

[http] customizable requests user-agent
parents a5a06a85 15d897ff
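
In short: downloads made through this module now send a descriptive user-agent header, and callers can extend it with their own string or dict. A hedged sketch of how the new parameter might be used (the model identifier and extra fields are illustrative; `cached_path` and `hf_bucket_url` are the functions touched by this diff):

```python
# Illustrative sketch, not part of the diff itself.
from transformers.file_utils import cached_path, hf_bucket_url

url = hf_bucket_url("bert-base-uncased", postfix="config.json")

# Append a plain string to the default "transformers/<ver>; python/<ver>" UA:
path = cached_path(url, user_agent="my-app/1.0")

# Or append key/value pairs from a dict:
path = cached_path(url, user_agent={"my-app": "1.0", "experiment": "ner"})
```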
@@ -23,6 +23,7 @@ from botocore.exceptions import ClientError
 import requests
 from tqdm.auto import tqdm
 from contextlib import contextmanager
+from . import __version__

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
@@ -77,6 +78,7 @@ DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
 DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]

 S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
+CLOUDFRONT_DISTRIB_PREFIX = "https://d2ws9o8vfrpkyk.cloudfront.net"

 def is_torch_available():
@@ -114,11 +116,12 @@ def is_remote_url(url_or_filename):
     parsed = urlparse(url_or_filename)
     return parsed.scheme in ('http', 'https', 's3')

-def hf_bucket_url(identifier, postfix=None):
+def hf_bucket_url(identifier, postfix=None, cdn=False):
+    endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
     if postfix is None:
-        return "/".join((S3_BUCKET_PREFIX, identifier))
+        return "/".join((endpoint, identifier))
     else:
-        return "/".join((S3_BUCKET_PREFIX, identifier, postfix))
+        return "/".join((endpoint, identifier, postfix))

 def url_to_filename(url, etag=None):
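
For reference, the new `cdn` switch in isolation; a self-contained sketch with the constants copied from this diff:

```python
# Self-contained copy of the hunk above, runnable on its own.
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://d2ws9o8vfrpkyk.cloudfront.net"

def hf_bucket_url(identifier, postfix=None, cdn=False):
    # cdn=True routes through CloudFront instead of the S3 bucket
    endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
    if postfix is None:
        return "/".join((endpoint, identifier))
    else:
        return "/".join((endpoint, identifier, postfix))

print(hf_bucket_url("bert-base-uncased", "config.json"))
# -> https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased/config.json
print(hf_bucket_url("bert-base-uncased", "config.json", cdn=True))
# -> https://d2ws9o8vfrpkyk.cloudfront.net/bert-base-uncased/config.json
```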
@@ -126,7 +129,7 @@ def url_to_filename(url, etag=None):
     Convert `url` into a hashed filename in a repeatable way.
     If `etag` is specified, append its hash to the url's, delimited
     by a period.
-    If the url ends with .h5 (Keras HDF5 weights) ands '.h5' to the name
+    If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name
     so that TF 2.0 can identify it as a HDF5 file
     (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
     """
@@ -171,7 +174,7 @@ def filename_to_url(filename, cache_dir=None):
     return url, etag

-def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False):
+def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None):
     """
     Given something that might be a URL (or might be a local path),
     determine which. If it's a URL, download the file and cache it, and
@@ -181,6 +184,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
         cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
         force_download: if True, re-download the file even if it's already cached in the cache dir.
         resume_download: if True, resume the download if an incompletely received file is found.
+        user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
     """
     if cache_dir is None:
         cache_dir = TRANSFORMERS_CACHE
@@ -193,7 +197,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
         # URL, so get it from the cache (downloading if necessary)
         return get_from_cache(url_or_filename, cache_dir=cache_dir,
                               force_download=force_download, proxies=proxies,
-                              resume_download=resume_download)
+                              resume_download=resume_download, user_agent=user_agent)
     elif os.path.exists(url_or_filename):
         # File, and it exists.
         return url_or_filename
@@ -254,8 +258,19 @@ def s3_get(url, temp_file, proxies=None):
     s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)

-def http_get(url, temp_file, proxies=None, resume_size=0):
-    headers = {'Range': 'bytes=%d-' % (resume_size,)} if resume_size > 0 else None
+def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
+    ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
+    if isinstance(user_agent, dict):
+        ua += "; " + "; ".join(
+            "{}/{}".format(k, v) for k, v in user_agent.items()
+        )
+    elif isinstance(user_agent, six.string_types):
+        ua += "; " + user_agent
+    headers = {
+        "user-agent": ua
+    }
+    if resume_size > 0:
+        headers['Range'] = 'bytes=%d-' % (resume_size,)
     response = requests.get(url, stream=True, proxies=proxies, headers=headers)
     if response.status_code == 416:  # Range not satisfiable
         return
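
To make the header construction concrete, here is the string the new code assembles for a dict-valued `user_agent` (version numbers are placeholders):

```python
import sys

__version__ = "2.2.0"  # placeholder for transformers.__version__

ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
user_agent = {"my-app": "1.0", "experiment": "ner"}
if isinstance(user_agent, dict):
    ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
print(ua)
# e.g. "transformers/2.2.0; python/3.7.4; my-app/1.0; experiment/ner"
```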
@@ -269,7 +284,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0):
     progress.close()

-def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False):
+def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None):
     """
     Given a URL, look for the corresponding dataset in the local cache.
     If it's not there, download it. Then return the path to the cached file.
@@ -340,7 +355,7 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
                 logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
                 s3_get(url, temp_file, proxies=proxies)
             else:
-                http_get(url, temp_file, proxies=proxies, resume_size=resume_size)
+                http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)

             # we are copying the file before closing it, so flush to avoid truncation
             temp_file.flush()
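
Taken together, `user_agent` now threads through the whole download path: `cached_path` forwards it to `get_from_cache`, which forwards it to `http_get`, where the final header is built. `s3_get` is untouched, since boto3 issues its own requests.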