chenpangpang / transformers · Commits

Unverified commit ff36e6d8, authored Dec 20, 2019 by Thomas Wolf, committed by GitHub Dec 20, 2019.

Merge pull request #2231 from huggingface/requests_user_agent

[http] customizable requests user-agent

Parents: a5a06a85, 15d897ff
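In brief: this merge threads a new optional user_agent argument from cached_path through get_from_cache into http_get, where it is appended to a default "transformers/<version>; python/<version>" identification string sent with download requests; it also adds a cdn flag to hf_bucket_url, backed by a new CloudFront endpoint constant.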
Showing 1 changed file with 25 additions and 10 deletions.

transformers/file_utils.py (+25 −10)
@@ -23,6 +23,7 @@ from botocore.exceptions import ClientError
 import requests
 from tqdm.auto import tqdm
 from contextlib import contextmanager
+from . import __version__

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
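(The new import makes the package version available to http_get further down, where it is embedded in the default user-agent string.)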
@@ -77,6 +78,7 @@ DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
 DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]

 S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
+CLOUDFRONT_DISTRIB_PREFIX = "https://d2ws9o8vfrpkyk.cloudfront.net"

 def is_torch_available():
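(The new CloudFront constant is consumed by the cdn flag added to hf_bucket_url in the next hunk.)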
@@ -114,11 +116,12 @@ def is_remote_url(url_or_filename):
     parsed = urlparse(url_or_filename)
     return parsed.scheme in ('http', 'https', 's3')

-def hf_bucket_url(identifier, postfix=None):
+def hf_bucket_url(identifier, postfix=None, cdn=False):
+    endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
     if postfix is None:
-        return "/".join((S3_BUCKET_PREFIX, identifier))
+        return "/".join((endpoint, identifier))
     else:
-        return "/".join((S3_BUCKET_PREFIX, identifier, postfix))
+        return "/".join((endpoint, identifier, postfix))

 def url_to_filename(url, etag=None):
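The cdn flag only swaps the endpoint prefix; identifier and postfix handling are unchanged. A minimal self-contained sketch of the post-change behavior, using hypothetical example inputs:

```python
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://d2ws9o8vfrpkyk.cloudfront.net"

def hf_bucket_url(identifier, postfix=None, cdn=False):
    # Reproduced from the new version in the hunk above.
    endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
    if postfix is None:
        return "/".join((endpoint, identifier))
    else:
        return "/".join((endpoint, identifier, postfix))

# "bert-base-uncased" / "pytorch_model.bin" are illustrative inputs.
print(hf_bucket_url("bert-base-uncased", postfix="pytorch_model.bin"))
# https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased/pytorch_model.bin
print(hf_bucket_url("bert-base-uncased", postfix="pytorch_model.bin", cdn=True))
# https://d2ws9o8vfrpkyk.cloudfront.net/bert-base-uncased/pytorch_model.bin
```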
@@ -126,7 +129,7 @@ def url_to_filename(url, etag=None):
     Convert `url` into a hashed filename in a repeatable way.
     If `etag` is specified, append its hash to the url's, delimited
     by a period.
-    If the url ends with .h5 (Keras HDF5 weights) ands '.h5' to the name
+    If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name
     so that TF 2.0 can identify it as a HDF5 file
     (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
     """
@@ -171,7 +174,7 @@ def filename_to_url(filename, cache_dir=None):
     return url, etag

-def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False):
+def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None):
     """
     Given something that might be a URL (or might be a local path),
     determine which. If it's a URL, download the file and cache it, and
@@ -181,6 +184,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
         cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
         force_download: if True, re-dowload the file even if it's already cached in the cache dir.
         resume_download: if True, resume the download if incompletly recieved file is found.
+        user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
     """
     if cache_dir is None:
         cache_dir = TRANSFORMERS_CACHE
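As a usage sketch (hypothetical call sites, not part of the commit), the new parameter accepts either a plain string or a dict of key/value pairs:

```python
from transformers.file_utils import cached_path

# A string is appended verbatim to the default user-agent.
path = cached_path(
    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased/config.json",
    user_agent="my-app/1.2",
)

# A dict is joined as "key/value" pairs, here "file_type/config; from_pipeline/ner".
path = cached_path(
    "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased/config.json",
    user_agent={"file_type": "config", "from_pipeline": "ner"},
)
```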
@@ -193,7 +197,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
         # URL, so get it from the cache (downloading if necessary)
         return get_from_cache(url_or_filename, cache_dir=cache_dir,
                               force_download=force_download, proxies=proxies,
-                              resume_download=resume_download)
+                              resume_download=resume_download, user_agent=user_agent)
     elif os.path.exists(url_or_filename):
         # File, and it exists.
         return url_or_filename
@@ -254,8 +258,19 @@ def s3_get(url, temp_file, proxies=None):
     s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)

-def http_get(url, temp_file, proxies=None, resume_size=0):
-    headers = {'Range': 'bytes=%d-' % (resume_size,)} if resume_size > 0 else None
+def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
+    ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
+    if isinstance(user_agent, dict):
+        ua += "; " + "; ".join(
+            "{}/{}".format(k, v) for k, v in user_agent.items()
+        )
+    elif isinstance(user_agent, six.string_types):
+        ua += "; " + user_agent
+    headers = {
+        "user-agent": ua
+    }
+    if resume_size > 0:
+        headers['Range'] = 'bytes=%d-' % (resume_size,)
     response = requests.get(url, stream=True, proxies=proxies, headers=headers)
     if response.status_code == 416:  # Range not satisfiable
         return
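For illustration, here is the header assembly above reproduced standalone; the version and user_agent values are assumed purely for the example:

```python
import sys

__version__ = "2.3.0"  # assumed here; the library gets it via `from . import __version__`
user_agent = {"file_type": "model"}  # example value

ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
if isinstance(user_agent, dict):
    ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
elif isinstance(user_agent, str):  # the diff uses six.string_types for py2/py3 compatibility
    ua += "; " + user_agent

headers = {"user-agent": ua}
print(headers)
# e.g. {'user-agent': 'transformers/2.3.0; python/3.6.9; file_type/model'}
```

Note that the Range header for resumed downloads is now merged into the same dict instead of replacing the headers entirely, so resumed downloads also carry the user-agent.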
@@ -269,7 +284,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0):
     progress.close()

-def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False):
+def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None):
     """
     Given a URL, look for the corresponding dataset in the local cache.
     If it's not there, download it. Then return the path to the cached file.
@@ -340,7 +355,7 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
                 logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
             s3_get(url, temp_file, proxies=proxies)
         else:
-            http_get(url, temp_file, proxies=proxies, resume_size=resume_size)
+            http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)

         # we are copying the file before closing it, so flush to avoid truncation
         temp_file.flush()
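Taken together, the change is purely additive: every public entry point gains an optional user_agent that defaults to None, and the dict-or-string handling in http_get presumably lets internal callers report structured key/value pairs while end users append a free-form token, both funneling into the single header built there.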