Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
265550ec
Commit
265550ec
authored
Apr 17, 2019
by
thomwolf
Browse files
relax network connection requirements
parent
fa765202
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
8 deletions
+23
-8
pytorch_pretrained_bert/file_utils.py
pytorch_pretrained_bert/file_utils.py
+22
-7
tests/tokenization_gpt2_test.py
tests/tokenization_gpt2_test.py
+1
-1
No files found.
pytorch_pretrained_bert/file_utils.py
View file @
265550ec
...
...
@@ -5,11 +5,13 @@ Copyright by the AllenNLP authors.
"""
from
__future__
import
(
absolute_import
,
division
,
print_function
,
unicode_literals
)
import
sys
import
json
import
logging
import
os
import
shutil
import
tempfile
import
fnmatch
from
functools
import
wraps
from
hashlib
import
sha256
import
sys
...
...
@@ -191,17 +193,30 @@ def get_from_cache(url, cache_dir=None):
if
url
.
startswith
(
"s3://"
):
etag
=
s3_etag
(
url
)
else
:
try
:
response
=
requests
.
head
(
url
,
allow_redirects
=
True
)
if
response
.
status_code
!=
200
:
raise
IOError
(
"HEAD request failed for url {} with status code {}"
.
format
(
url
,
response
.
status_code
))
etag
=
None
else
:
etag
=
response
.
headers
.
get
(
"ETag"
)
except
EnvironmentError
:
etag
=
None
if
sys
.
version_info
[
0
]
==
2
and
etag
is
not
None
:
etag
=
etag
.
decode
(
'utf-8'
)
filename
=
url_to_filename
(
url
,
etag
)
# get cache path to put the file
cache_path
=
os
.
path
.
join
(
cache_dir
,
filename
)
# If we don't have a connection (etag is None) and can't identify the file
# try to get the last downloaded one
if
not
os
.
path
.
exists
(
cache_path
)
and
etag
is
None
:
matching_files
=
fnmatch
.
filter
(
os
.
listdir
(
cache_dir
),
filename
+
'.*'
)
matching_files
=
list
(
filter
(
lambda
s
:
not
s
.
endswith
(
'.json'
),
matching_files
))
if
matching_files
:
cache_path
=
os
.
path
.
join
(
cache_dir
,
matching_files
[
-
1
])
if
not
os
.
path
.
exists
(
cache_path
):
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
...
...
@@ -226,8 +241,8 @@ def get_from_cache(url, cache_dir=None):
logger
.
info
(
"creating metadata file for %s"
,
cache_path
)
meta
=
{
'url'
:
url
,
'etag'
:
etag
}
meta_path
=
cache_path
+
'.json'
with
open
(
meta_path
,
'w'
,
encoding
=
"utf-8"
)
as
meta_file
:
json
.
dump
(
meta
,
meta_file
)
with
open
(
meta_path
,
'w'
)
as
meta_file
:
meta_file
.
write
(
json
.
dump
s
(
meta
,
indent
=
4
)
)
logger
.
info
(
"removing temp file %s"
,
temp_file
.
name
)
...
...
tests/tokenization_gpt2_test.py
View file @
265550ec
...
...
@@ -66,7 +66,7 @@ class GPT2TokenizationTest(unittest.TestCase):
[
tokenizer_2
.
encoder
,
tokenizer_2
.
decoder
,
tokenizer_2
.
bpe_ranks
,
tokenizer_2
.
special_tokens
,
tokenizer_2
.
special_tokens_decoder
])
@
pytest
.
mark
.
slow
#
@pytest.mark.slow
def
test_tokenizer_from_pretrained
(
self
):
cache_dir
=
"/tmp/pytorch_pretrained_bert_test/"
for
model_name
in
list
(
PRETRAINED_VOCAB_ARCHIVE_MAP
.
keys
())[:
1
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment