Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
0e4cc050
Commit
0e4cc050
authored
Oct 24, 2019
by
Sergey Mironov
Browse files
Add support for resumable downloads for HTTP protocol.
parent
0e64fec1
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
87 additions
and
15 deletions
+87
-15
transformers/configuration_auto.py
transformers/configuration_auto.py
+3
-0
transformers/configuration_utils.py
transformers/configuration_utils.py
+6
-1
transformers/file_utils.py
transformers/file_utils.py
+36
-11
transformers/modeling_auto.py
transformers/modeling_auto.py
+8
-0
transformers/modeling_tf_auto.py
transformers/modeling_tf_auto.py
+12
-0
transformers/modeling_tf_utils.py
transformers/modeling_tf_utils.py
+7
-1
transformers/modeling_utils.py
transformers/modeling_utils.py
+7
-1
transformers/tokenization_auto.py
transformers/tokenization_auto.py
+3
-0
transformers/tokenization_utils.py
transformers/tokenization_utils.py
+5
-1
No files found.
transformers/configuration_auto.py
View file @
0e4cc050
...
@@ -92,6 +92,9 @@ class AutoConfig(object):
...
@@ -92,6 +92,9 @@ class AutoConfig(object):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
The proxies are used on each request.
...
...
transformers/configuration_utils.py
View file @
0e4cc050
...
@@ -93,6 +93,9 @@ class PretrainedConfig(object):
...
@@ -93,6 +93,9 @@ class PretrainedConfig(object):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
The proxies are used on each request.
...
@@ -119,6 +122,7 @@ class PretrainedConfig(object):
...
@@ -119,6 +122,7 @@ class PretrainedConfig(object):
"""
"""
cache_dir
=
kwargs
.
pop
(
'cache_dir'
,
None
)
cache_dir
=
kwargs
.
pop
(
'cache_dir'
,
None
)
force_download
=
kwargs
.
pop
(
'force_download'
,
False
)
force_download
=
kwargs
.
pop
(
'force_download'
,
False
)
resume_download
=
kwargs
.
pop
(
'resume_download'
,
False
)
proxies
=
kwargs
.
pop
(
'proxies'
,
None
)
proxies
=
kwargs
.
pop
(
'proxies'
,
None
)
return_unused_kwargs
=
kwargs
.
pop
(
'return_unused_kwargs'
,
False
)
return_unused_kwargs
=
kwargs
.
pop
(
'return_unused_kwargs'
,
False
)
...
@@ -130,7 +134,8 @@ class PretrainedConfig(object):
...
@@ -130,7 +134,8 @@ class PretrainedConfig(object):
config_file
=
pretrained_model_name_or_path
config_file
=
pretrained_model_name_or_path
# redirect to the cache, if necessary
# redirect to the cache, if necessary
try
:
try
:
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
,
resume_download
=
resume_download
)
except
EnvironmentError
:
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_config_archive_map
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_config_archive_map
:
msg
=
"Couldn't reach server at '{}' to download pretrained model configuration file."
.
format
(
msg
=
"Couldn't reach server at '{}' to download pretrained model configuration file."
.
format
(
...
...
transformers/file_utils.py
View file @
0e4cc050
...
@@ -22,6 +22,7 @@ from botocore.config import Config
...
@@ -22,6 +22,7 @@ from botocore.config import Config
from
botocore.exceptions
import
ClientError
from
botocore.exceptions
import
ClientError
import
requests
import
requests
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
contextlib
import
contextmanager
logger
=
logging
.
getLogger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
getLogger
(
__name__
)
# pylint: disable=invalid-name
...
@@ -152,7 +153,7 @@ def filename_to_url(filename, cache_dir=None):
...
@@ -152,7 +153,7 @@ def filename_to_url(filename, cache_dir=None):
return
url
,
etag
return
url
,
etag
def
cached_path
(
url_or_filename
,
cache_dir
=
None
,
force_download
=
False
,
proxies
=
None
):
def
cached_path
(
url_or_filename
,
cache_dir
=
None
,
force_download
=
False
,
proxies
=
None
,
resume_download
=
False
):
"""
"""
Given something that might be a URL (or might be a local path),
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
determine which. If it's a URL, download the file and cache it, and
...
@@ -161,6 +162,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
...
@@ -161,6 +162,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
Args:
Args:
cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
force_download: if True, re-dowload the file even if it's already cached in the cache dir.
force_download: if True, re-dowload the file even if it's already cached in the cache dir.
resume_download: if True, resume the download if incompletly recieved file is found.
"""
"""
if
cache_dir
is
None
:
if
cache_dir
is
None
:
cache_dir
=
TRANSFORMERS_CACHE
cache_dir
=
TRANSFORMERS_CACHE
...
@@ -173,7 +175,9 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
...
@@ -173,7 +175,9 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
if
parsed
.
scheme
in
(
'http'
,
'https'
,
's3'
):
if
parsed
.
scheme
in
(
'http'
,
'https'
,
's3'
):
# URL, so get it from the cache (downloading if necessary)
# URL, so get it from the cache (downloading if necessary)
return
get_from_cache
(
url_or_filename
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
return
get_from_cache
(
url_or_filename
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
,
resume_download
=
resume_download
)
elif
os
.
path
.
exists
(
url_or_filename
):
elif
os
.
path
.
exists
(
url_or_filename
):
# File, and it exists.
# File, and it exists.
return
url_or_filename
return
url_or_filename
...
@@ -234,19 +238,22 @@ def s3_get(url, temp_file, proxies=None):
...
@@ -234,19 +238,22 @@ def s3_get(url, temp_file, proxies=None):
s3_resource
.
Bucket
(
bucket_name
).
download_fileobj
(
s3_path
,
temp_file
)
s3_resource
.
Bucket
(
bucket_name
).
download_fileobj
(
s3_path
,
temp_file
)
def
http_get
(
url
,
temp_file
,
proxies
=
None
):
def
http_get
(
url
,
temp_file
,
proxies
=
None
,
resume_size
=
0
):
req
=
requests
.
get
(
url
,
stream
=
True
,
proxies
=
proxies
)
headers
=
{
'Range'
:
'bytes=%d-'
%
(
resume_size
,)}
if
resume_size
>
0
else
None
content_length
=
req
.
headers
.
get
(
'Content-Length'
)
response
=
requests
.
get
(
url
,
stream
=
True
,
proxies
=
proxies
,
headers
=
headers
)
total
=
int
(
content_length
)
if
content_length
is
not
None
else
None
if
response
.
status_code
==
416
:
# Range not satisfiable
progress
=
tqdm
(
unit
=
"B"
,
total
=
total
)
return
for
chunk
in
req
.
iter_content
(
chunk_size
=
1024
):
content_length
=
response
.
headers
.
get
(
'Content-Length'
)
total
=
resume_size
+
int
(
content_length
)
if
content_length
is
not
None
else
None
progress
=
tqdm
(
unit
=
"B"
,
total
=
total
,
initial
=
resume_size
)
for
chunk
in
response
.
iter_content
(
chunk_size
=
1024
):
if
chunk
:
# filter out keep-alive new chunks
if
chunk
:
# filter out keep-alive new chunks
progress
.
update
(
len
(
chunk
))
progress
.
update
(
len
(
chunk
))
temp_file
.
write
(
chunk
)
temp_file
.
write
(
chunk
)
progress
.
close
()
progress
.
close
()
def
get_from_cache
(
url
,
cache_dir
=
None
,
force_download
=
False
,
proxies
=
None
,
etag_timeout
=
10
):
def
get_from_cache
(
url
,
cache_dir
=
None
,
force_download
=
False
,
proxies
=
None
,
etag_timeout
=
10
,
resume_download
=
False
):
"""
"""
Given a URL, look for the corresponding dataset in the local cache.
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
If it's not there, download it. Then return the path to the cached file.
...
@@ -289,17 +296,35 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
...
@@ -289,17 +296,35 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
if
matching_files
:
if
matching_files
:
cache_path
=
os
.
path
.
join
(
cache_dir
,
matching_files
[
-
1
])
cache_path
=
os
.
path
.
join
(
cache_dir
,
matching_files
[
-
1
])
if
resume_download
:
incomplete_path
=
cache_path
+
'.incomplete'
@
contextmanager
def
_resumable_file_manager
():
with
open
(
incomplete_path
,
'a+b'
)
as
f
:
yield
f
os
.
remove
(
incomplete_path
)
temp_file_manager
=
_resumable_file_manager
if
os
.
path
.
exists
(
incomplete_path
):
resume_size
=
os
.
stat
(
incomplete_path
).
st_size
else
:
resume_size
=
0
else
:
temp_file_manager
=
tempfile
.
NamedTemporaryFile
resume_size
=
0
if
not
os
.
path
.
exists
(
cache_path
)
or
force_download
:
if
not
os
.
path
.
exists
(
cache_path
)
or
force_download
:
# Download to temporary file, then copy to cache dir once finished.
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with
tempfile
.
NamedTemporaryFile
()
as
temp_file
:
with
temp
_
file
_manager
()
as
temp_file
:
logger
.
info
(
"%s not found in cache or force_download set to True, downloading to %s"
,
url
,
temp_file
.
name
)
logger
.
info
(
"%s not found in cache or force_download set to True, downloading to %s"
,
url
,
temp_file
.
name
)
# GET file object
# GET file object
if
url
.
startswith
(
"s3://"
):
if
url
.
startswith
(
"s3://"
):
if
resume_download
:
logger
.
warn
(
'Warning: resumable downloads are not implemented for "s3://" urls'
)
s3_get
(
url
,
temp_file
,
proxies
=
proxies
)
s3_get
(
url
,
temp_file
,
proxies
=
proxies
)
else
:
else
:
http_get
(
url
,
temp_file
,
proxies
=
proxies
)
http_get
(
url
,
temp_file
,
proxies
=
proxies
,
resume_size
=
resume_size
)
# we are copying the file before closing it, so flush to avoid truncation
# we are copying the file before closing it, so flush to avoid truncation
temp_file
.
flush
()
temp_file
.
flush
()
...
...
transformers/modeling_auto.py
View file @
0e4cc050
...
@@ -112,6 +112,9 @@ class AutoModel(object):
...
@@ -112,6 +112,9 @@ class AutoModel(object):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
The proxies are used on each request.
...
@@ -237,6 +240,8 @@ class AutoModelWithLMHead(object):
...
@@ -237,6 +240,8 @@ class AutoModelWithLMHead(object):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
...
@@ -357,6 +362,9 @@ class AutoModelForSequenceClassification(object):
...
@@ -357,6 +362,9 @@ class AutoModelForSequenceClassification(object):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
The proxies are used on each request.
...
...
transformers/modeling_tf_auto.py
View file @
0e4cc050
...
@@ -109,6 +109,9 @@ class TFAutoModel(object):
...
@@ -109,6 +109,9 @@ class TFAutoModel(object):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
The proxies are used on each request.
...
@@ -237,6 +240,9 @@ class TFAutoModelWithLMHead(object):
...
@@ -237,6 +240,9 @@ class TFAutoModelWithLMHead(object):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
The proxies are used on each request.
...
@@ -360,6 +366,9 @@ class TFAutoModelForSequenceClassification(object):
...
@@ -360,6 +366,9 @@ class TFAutoModelForSequenceClassification(object):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
The proxies are used on each request.
...
@@ -472,6 +481,9 @@ class TFAutoModelForQuestionAnswering(object):
...
@@ -472,6 +481,9 @@ class TFAutoModelForQuestionAnswering(object):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
The proxies are used on each request.
...
...
transformers/modeling_tf_utils.py
View file @
0e4cc050
...
@@ -176,6 +176,9 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -176,6 +176,9 @@ class TFPreTrainedModel(tf.keras.Model):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
The proxies are used on each request.
...
@@ -201,6 +204,7 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -201,6 +204,7 @@ class TFPreTrainedModel(tf.keras.Model):
cache_dir
=
kwargs
.
pop
(
'cache_dir'
,
None
)
cache_dir
=
kwargs
.
pop
(
'cache_dir'
,
None
)
from_pt
=
kwargs
.
pop
(
'from_pt'
,
False
)
from_pt
=
kwargs
.
pop
(
'from_pt'
,
False
)
force_download
=
kwargs
.
pop
(
'force_download'
,
False
)
force_download
=
kwargs
.
pop
(
'force_download'
,
False
)
resume_download
=
kwargs
.
pop
(
'resume_download'
,
False
)
proxies
=
kwargs
.
pop
(
'proxies'
,
None
)
proxies
=
kwargs
.
pop
(
'proxies'
,
None
)
# Load config
# Load config
...
@@ -209,6 +213,7 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -209,6 +213,7 @@ class TFPreTrainedModel(tf.keras.Model):
pretrained_model_name_or_path
,
*
model_args
,
pretrained_model_name_or_path
,
*
model_args
,
cache_dir
=
cache_dir
,
return_unused_kwargs
=
True
,
cache_dir
=
cache_dir
,
return_unused_kwargs
=
True
,
force_download
=
force_download
,
force_download
=
force_download
,
resume_download
=
resume_download
,
**
kwargs
**
kwargs
)
)
else
:
else
:
...
@@ -236,7 +241,8 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -236,7 +241,8 @@ class TFPreTrainedModel(tf.keras.Model):
# redirect to the cache, if necessary
# redirect to the cache, if necessary
try
:
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
resume_download
=
resume_download
,
proxies
=
proxies
)
except
EnvironmentError
as
e
:
except
EnvironmentError
as
e
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
logger
.
error
(
logger
.
error
(
...
...
transformers/modeling_utils.py
View file @
0e4cc050
...
@@ -246,6 +246,9 @@ class PreTrainedModel(nn.Module):
...
@@ -246,6 +246,9 @@ class PreTrainedModel(nn.Module):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
The proxies are used on each request.
...
@@ -275,6 +278,7 @@ class PreTrainedModel(nn.Module):
...
@@ -275,6 +278,7 @@ class PreTrainedModel(nn.Module):
cache_dir
=
kwargs
.
pop
(
'cache_dir'
,
None
)
cache_dir
=
kwargs
.
pop
(
'cache_dir'
,
None
)
from_tf
=
kwargs
.
pop
(
'from_tf'
,
False
)
from_tf
=
kwargs
.
pop
(
'from_tf'
,
False
)
force_download
=
kwargs
.
pop
(
'force_download'
,
False
)
force_download
=
kwargs
.
pop
(
'force_download'
,
False
)
resume_download
=
kwargs
.
pop
(
'resume_download'
,
False
)
proxies
=
kwargs
.
pop
(
'proxies'
,
None
)
proxies
=
kwargs
.
pop
(
'proxies'
,
None
)
output_loading_info
=
kwargs
.
pop
(
'output_loading_info'
,
False
)
output_loading_info
=
kwargs
.
pop
(
'output_loading_info'
,
False
)
...
@@ -284,6 +288,7 @@ class PreTrainedModel(nn.Module):
...
@@ -284,6 +288,7 @@ class PreTrainedModel(nn.Module):
pretrained_model_name_or_path
,
*
model_args
,
pretrained_model_name_or_path
,
*
model_args
,
cache_dir
=
cache_dir
,
return_unused_kwargs
=
True
,
cache_dir
=
cache_dir
,
return_unused_kwargs
=
True
,
force_download
=
force_download
,
force_download
=
force_download
,
resume_download
=
resume_download
,
**
kwargs
**
kwargs
)
)
else
:
else
:
...
@@ -315,7 +320,8 @@ class PreTrainedModel(nn.Module):
...
@@ -315,7 +320,8 @@ class PreTrainedModel(nn.Module):
# redirect to the cache, if necessary
# redirect to the cache, if necessary
try
:
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
,
resume_download
=
resume_download
)
except
EnvironmentError
:
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
msg
=
"Couldn't reach server at '{}' to download pretrained weights."
.
format
(
msg
=
"Couldn't reach server at '{}' to download pretrained weights."
.
format
(
...
...
transformers/tokenization_auto.py
View file @
0e4cc050
...
@@ -87,6 +87,9 @@ class AutoTokenizer(object):
...
@@ -87,6 +87,9 @@ class AutoTokenizer(object):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the vocabulary files and override the cached versions if they exists.
Force to (re-)download the vocabulary files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
The proxies are used on each request.
...
...
transformers/tokenization_utils.py
View file @
0e4cc050
...
@@ -251,6 +251,9 @@ class PreTrainedTokenizer(object):
...
@@ -251,6 +251,9 @@ class PreTrainedTokenizer(object):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the vocabulary files and override the cached versions if they exists.
Force to (re-)download the vocabulary files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
The proxies are used on each request.
...
@@ -286,6 +289,7 @@ class PreTrainedTokenizer(object):
...
@@ -286,6 +289,7 @@ class PreTrainedTokenizer(object):
def
_from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
init_inputs
,
**
kwargs
):
def
_from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
init_inputs
,
**
kwargs
):
cache_dir
=
kwargs
.
pop
(
'cache_dir'
,
None
)
cache_dir
=
kwargs
.
pop
(
'cache_dir'
,
None
)
force_download
=
kwargs
.
pop
(
'force_download'
,
False
)
force_download
=
kwargs
.
pop
(
'force_download'
,
False
)
resume_download
=
kwargs
.
pop
(
'resume_download'
,
False
)
proxies
=
kwargs
.
pop
(
'proxies'
,
None
)
proxies
=
kwargs
.
pop
(
'proxies'
,
None
)
s3_models
=
list
(
cls
.
max_model_input_sizes
.
keys
())
s3_models
=
list
(
cls
.
max_model_input_sizes
.
keys
())
...
@@ -352,7 +356,7 @@ class PreTrainedTokenizer(object):
...
@@ -352,7 +356,7 @@ class PreTrainedTokenizer(object):
if
file_path
is
None
:
if
file_path
is
None
:
resolved_vocab_files
[
file_id
]
=
None
resolved_vocab_files
[
file_id
]
=
None
else
:
else
:
resolved_vocab_files
[
file_id
]
=
cached_path
(
file_path
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
resolved_vocab_files
[
file_id
]
=
cached_path
(
file_path
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
,
resume_download
=
resume_download
)
except
EnvironmentError
:
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
s3_models
:
if
pretrained_model_name_or_path
in
s3_models
:
msg
=
"Couldn't reach server at '{}' to download vocabulary files."
msg
=
"Couldn't reach server at '{}' to download vocabulary files."
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment