Commit 265550ec authored by thomwolf

relax network connection requirements

parent fa765202
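In short: the `HEAD` request used to look up a file's `ETag` can now fail (no network, non-200 response) without raising; `etag` simply becomes `None`, and if the exact cache entry cannot then be identified, `get_from_cache` falls back to a previously downloaded file matching the URL's hash, so already-cached models keep working offline. Illustrative sketches follow the relevant hunks below.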
@@ -5,11 +5,13 @@ Copyright by the AllenNLP authors.
 """
 from __future__ import (absolute_import, division, print_function, unicode_literals)
 
+import sys
 import json
 import logging
 import os
 import shutil
 import tempfile
+import fnmatch
 from functools import wraps
 from hashlib import sha256
 import sys
@@ -191,17 +193,30 @@ def get_from_cache(url, cache_dir=None):
     if url.startswith("s3://"):
         etag = s3_etag(url)
     else:
-        response = requests.head(url, allow_redirects=True)
-        if response.status_code != 200:
-            raise IOError("HEAD request failed for url {} with status code {}"
-                          .format(url, response.status_code))
-        etag = response.headers.get("ETag")
+        try:
+            response = requests.head(url, allow_redirects=True)
+            if response.status_code != 200:
+                etag = None
+            else:
+                etag = response.headers.get("ETag")
+        except EnvironmentError:
+            etag = None
+
+    if sys.version_info[0] == 2 and etag is not None:
+        etag = etag.decode('utf-8')
 
     filename = url_to_filename(url, etag)
 
     # get cache path to put the file
     cache_path = os.path.join(cache_dir, filename)
+    # If we don't have a connection (etag is None) and can't identify the file
+    # try to get the last downloaded one
+    if not os.path.exists(cache_path) and etag is None:
+        matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
+        matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
+        if matching_files:
+            cache_path = os.path.join(cache_dir, matching_files[-1])
 
     if not os.path.exists(cache_path):
         # Download to temporary file, then copy to cache dir once finished.
         # Otherwise you get corrupt cache entries if the download gets interrupted.
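The fallback above is the heart of the change. Here is a minimal, self-contained sketch of the same pattern; the function name `resolve_cache_path` and the simplified `url_to_filename` are illustrative stand-ins, not the library's exact API:

```python
import fnmatch
import os
from hashlib import sha256

import requests


def url_to_filename(url, etag=None):
    # Simplified stand-in for the library's naming scheme:
    # hash of the URL, plus a hash of the ETag when one is known.
    filename = sha256(url.encode('utf-8')).hexdigest()
    if etag:
        filename += '.' + sha256(etag.encode('utf-8')).hexdigest()
    return filename


def resolve_cache_path(url, cache_dir):
    # Any network failure (DNS error, timeout, refused connection) or
    # non-200 status now just means "no ETag", not a hard error.
    try:
        response = requests.head(url, allow_redirects=True)
        etag = response.headers.get("ETag") if response.status_code == 200 else None
    except EnvironmentError:
        etag = None

    cache_path = os.path.join(cache_dir, url_to_filename(url, etag))

    # Offline and no exact match: fall back to any previous download for
    # this URL (same URL hash, any ETag suffix), skipping the '.json'
    # metadata sidecars; take the last match, as the commit does.
    if not os.path.exists(cache_path) and etag is None:
        candidates = fnmatch.filter(os.listdir(cache_dir), url_to_filename(url) + '.*')
        candidates = [f for f in candidates if not f.endswith('.json')]
        if candidates:
            cache_path = os.path.join(cache_dir, candidates[-1])
    return cache_path
```

Note that `os.listdir` gives no ordering guarantee, so `[-1]` picks an arbitrary matching file rather than strictly the most recent download; with a single cached version per URL, which is the common case, this does not matter.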
@@ -226,8 +241,8 @@ def get_from_cache(url, cache_dir=None):
             logger.info("creating metadata file for %s", cache_path)
             meta = {'url': url, 'etag': etag}
             meta_path = cache_path + '.json'
-            with open(meta_path, 'w', encoding="utf-8") as meta_file:
-                json.dump(meta, meta_file)
+            with open(meta_path, 'w') as meta_file:
+                meta_file.write(json.dumps(meta, indent=4))
 
             logger.info("removing temp file %s", temp_file.name)
...
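The metadata change also drops the `encoding` keyword, which Python 2's built-in `open()` does not accept, and switches to `json.dumps` with `indent=4` for readability. A tiny sketch of what gets written (the URL and output path are hypothetical):

```python
import json

# Hypothetical cache entry; in the library these values come from
# get_from_cache(). 'etag' is None when the file was resolved offline.
meta = {'url': 'https://example.com/bert-base-uncased-vocab.txt', 'etag': None}

# Mirrors the diff: plain open() (Python 2 compatible) plus
# meta_file.write(json.dumps(...)) instead of json.dump().
with open('cache_entry.json', 'w') as meta_file:
    meta_file.write(json.dumps(meta, indent=4))
```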
@@ -66,7 +66,7 @@ class GPT2TokenizationTest(unittest.TestCase):
                          [tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
                           tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])
 
-    @pytest.mark.slow
+    # @pytest.mark.slow
     def test_tokenizer_from_pretrained(self):
         cache_dir = "/tmp/pytorch_pretrained_bert_test/"
         for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
...
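Un-marking `test_tokenizer_from_pretrained` as slow means the default test run now exercises `from_pretrained`, and with it the new cache fallback, instead of skipping it; presumably it can be run on its own with something like `pytest -k test_tokenizer_from_pretrained` (exact test file path assumed from the class name above).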