Commit 23a2cea8 authored by Julien Chaumond

Tokenizer.from_pretrained: fetch all possible files remotely

parent 99f9243d
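
In practice this means `from_pretrained` can now resolve every tokenizer file (vocabulary, added tokens, special tokens map, tokenizer config) for a user-uploaded model identifier straight from the remote bucket. A minimal usage sketch of the new behaviour, reusing the identifiers exercised in the tests below (not part of the diff):

    from transformers import AutoTokenizer, BertTokenizer

    # User-uploaded identifier: all tokenizer files are fetched remotely and cached.
    tokenizer = AutoTokenizer.from_pretrained("wietsedv/bert-base-dutch-cased")
    assert isinstance(tokenizer, BertTokenizer)

    # A non-existent identifier raises EnvironmentError.
    try:
        BertTokenizer.from_pretrained("julien-c/herlolip-not-exists")
    except EnvironmentError:
        print("identifier not found and nothing cached locally")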
@@ -200,6 +200,8 @@ class PretrainedConfig(object):
                 resume_download=resume_download,
             )
             # Load config dict
+            if resolved_config_file is None:
+                raise EnvironmentError
             config_dict = cls._dict_from_json_file(resolved_config_file)
 
         except EnvironmentError:
@@ -210,7 +212,7 @@ class PretrainedConfig(object):
             else:
                 msg = (
                     "Model name '{}' was not found in model name list. "
-                    "We assumed '{}' was a path or url to a configuration file named {} or "
+                    "We assumed '{}' was a path, a model identifier, or url to a configuration file named {} or "
                     "a directory containing such a file but couldn't find any such file at this path or url.".format(
                         pretrained_model_name_or_path, config_file, CONFIG_NAME,
                     )
@@ -14,6 +14,7 @@ import tempfile
 from contextlib import contextmanager
 from functools import partial, wraps
 from hashlib import sha256
+from typing import Optional
 from urllib.parse import urlparse
 
 import boto3
@@ -122,7 +123,7 @@ def is_remote_url(url_or_filename):
     return parsed.scheme in ("http", "https", "s3")
 
 
-def hf_bucket_url(identifier, postfix=None, cdn=False):
+def hf_bucket_url(identifier, postfix=None, cdn=False) -> str:
     endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
     if postfix is None:
         return "/".join((endpoint, identifier))
@@ -182,7 +183,7 @@ def filename_to_url(filename, cache_dir=None):
 def cached_path(
     url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None
-):
+) -> Optional[str]:
     """
     Given something that might be a URL (or might be a local path),
     determine which. If it's a URL, download the file and cache it, and
@@ -193,6 +194,10 @@ def cached_path(
         force_download: if True, re-dowload the file even if it's already cached in the cache dir.
         resume_download: if True, resume the download if incompletly recieved file is found.
         user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
+
+    Return:
+        None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
+        Local path (string) otherwise
     """
     if cache_dir is None:
         cache_dir = TRANSFORMERS_CACHE
@@ -306,10 +311,14 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
 def get_from_cache(
     url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None
-):
+) -> Optional[str]:
     """
-    Given a URL, look for the corresponding dataset in the local cache.
+    Given a URL, look for the corresponding file in the local cache.
     If it's not there, download it. Then return the path to the cached file.
+
+    Return:
+        None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
+        Local path (string) otherwise
     """
     if cache_dir is None:
         cache_dir = TRANSFORMERS_CACHE
@@ -336,16 +345,25 @@ def get_from_cache(
     # get cache path to put the file
     cache_path = os.path.join(cache_dir, filename)
 
-    # If we don't have a connection (etag is None) and can't identify the file
+    # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
     # try to get the last downloaded one
-    if not os.path.exists(cache_path) and etag is None:
-        matching_files = [
-            file
-            for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
-            if not file.endswith(".json") and not file.endswith(".lock")
-        ]
-        if matching_files:
-            cache_path = os.path.join(cache_dir, matching_files[-1])
+    if etag is None:
+        if os.path.exists(cache_path):
+            return cache_path
+        else:
+            matching_files = [
+                file
+                for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
+                if not file.endswith(".json") and not file.endswith(".lock")
+            ]
+            if len(matching_files) > 0:
+                return os.path.join(cache_dir, matching_files[-1])
+            else:
+                return None
+
+    # From now on, etag is not None.
+    if os.path.exists(cache_path) and not force_download:
+        return cache_path
 
     # Prevent parallel downloads of the same file with a lock.
     lock_path = cache_path + ".lock"
@@ -368,29 +386,26 @@ def get_from_cache(
         temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
         resume_size = 0
 
-        if etag is not None and (not os.path.exists(cache_path) or force_download):
-            # Download to temporary file, then copy to cache dir once finished.
-            # Otherwise you get corrupt cache entries if the download gets interrupted.
-            with temp_file_manager() as temp_file:
-                logger.info(
-                    "%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name
-                )
-
-                # GET file object
-                if url.startswith("s3://"):
-                    if resume_download:
-                        logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
-                    s3_get(url, temp_file, proxies=proxies)
-                else:
-                    http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
-
-            logger.info("storing %s in cache at %s", url, cache_path)
-            os.rename(temp_file.name, cache_path)
-
-            logger.info("creating metadata file for %s", cache_path)
-            meta = {"url": url, "etag": etag}
-            meta_path = cache_path + ".json"
-            with open(meta_path, "w") as meta_file:
-                json.dump(meta, meta_file)
+        # Download to temporary file, then copy to cache dir once finished.
+        # Otherwise you get corrupt cache entries if the download gets interrupted.
+        with temp_file_manager() as temp_file:
+            logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
+
+            # GET file object
+            if url.startswith("s3://"):
+                if resume_download:
+                    logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
+                s3_get(url, temp_file, proxies=proxies)
+            else:
+                http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
+
+        logger.info("storing %s in cache at %s", url, cache_path)
+        os.rename(temp_file.name, cache_path)
+
+        logger.info("creating metadata file for %s", cache_path)
+        meta = {"url": url, "etag": etag}
+        meta_path = cache_path + ".json"
+        with open(meta_path, "w") as meta_file:
+            json.dump(meta, meta_file)
 
     return cache_path
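
With the `Optional[str]` return type, callers of `cached_path` / `get_from_cache` have to treat `None` as "not reachable remotely and not in the local cache", which is what the configuration loader above now does before parsing. A condensed caller-side sketch (the helper name is made up for illustration, not part of the diff):

    from transformers.file_utils import cached_path, hf_bucket_url

    def fetch_or_fail(identifier, filename):
        # Hypothetical helper: resolve a bucket file to a local path,
        # or raise if it is neither downloadable nor cached.
        url = hf_bucket_url(identifier, postfix=filename)
        resolved = cached_path(url)
        if resolved is None:
            raise EnvironmentError("Couldn't reach {} and found no cached copy".format(url))
        return resolved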
@@ -264,7 +264,7 @@ class PreTrainedTokenizer(object):
                 - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
                 - a string with the `identifier name` of a predefined tokenizer that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                 - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
-                - (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.
+                - (not applicable to all derived classes, deprecated) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.
 
             cache_dir: (`optional`) string:
                 Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.
@@ -331,57 +331,42 @@ class PreTrainedTokenizer(object):
             # Get the vocabulary from local files
             logger.info(
                 "Model name '{}' not found in model shortcut name list ({}). "
-                "Assuming '{}' is a path or url to a directory containing tokenizer files.".format(
+                "Assuming '{}' is a path, a model identifier, or url to a directory containing tokenizer files.".format(
                     pretrained_model_name_or_path, ", ".join(s3_models), pretrained_model_name_or_path
                 )
             )
 
-            # Look for the tokenizer main vocabulary files
-            for file_id, file_name in cls.vocab_files_names.items():
-                if os.path.isdir(pretrained_model_name_or_path):
-                    # If a directory is provided we look for the standard filenames
-                    full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
-                    if not os.path.exists(full_file_name):
-                        logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
-                        full_file_name = None
-                elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
-                    # If a path to a file is provided we use it (will only work for non-BPE tokenizer using a single vocabulary file)
-                    full_file_name = pretrained_model_name_or_path
-                else:
-                    full_file_name = hf_bucket_url(pretrained_model_name_or_path, postfix=file_name)
-
-                vocab_files[file_id] = full_file_name
-
-            # Look for the additional tokens files
-            additional_files_names = {
-                "added_tokens_file": ADDED_TOKENS_FILE,
-                "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
-                "tokenizer_config_file": TOKENIZER_CONFIG_FILE,
-            }
-
-            # If a path to a file was provided, get the parent directory
-            saved_directory = pretrained_model_name_or_path
-            if os.path.exists(saved_directory) and not os.path.isdir(saved_directory):
-                saved_directory = os.path.dirname(saved_directory)
-
-            for file_id, file_name in additional_files_names.items():
-                full_file_name = os.path.join(saved_directory, file_name)
-                if not os.path.exists(full_file_name):
-                    logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
-                    full_file_name = None
-
-                vocab_files[file_id] = full_file_name
-
-            if all(full_file_name is None for full_file_name in vocab_files.values()):
-                raise EnvironmentError(
-                    "Model name '{}' was not found in tokenizers model name list ({}). "
-                    "We assumed '{}' was a path or url to a directory containing vocabulary files "
-                    "named {} but couldn't find such vocabulary files at this path or url.".format(
-                        pretrained_model_name_or_path,
-                        ", ".join(s3_models),
-                        pretrained_model_name_or_path,
-                        list(cls.vocab_files_names.values()),
-                    )
-                )
+            if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
+                if len(cls.vocab_files_names) > 1:
+                    raise ValueError(
+                        "Calling {}.from_pretrained() with the path to a single file or url is not supported."
+                        "Use a model identifier or the path to a directory instead.".format(cls.__name__)
+                    )
+                logger.warning(
+                    "Calling {}.from_pretrained() with the path to a single file or url is deprecated".format(
+                        cls.__name__
+                    )
+                )
+                file_id = list(cls.vocab_files_names.keys())[0]
+                vocab_files[file_id] = pretrained_model_name_or_path
+            else:
+                # At this point pretrained_model_name_or_path is either a directory or a model identifier name
+                additional_files_names = {
+                    "added_tokens_file": ADDED_TOKENS_FILE,
+                    "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
+                    "tokenizer_config_file": TOKENIZER_CONFIG_FILE,
+                }
+                # Look for the tokenizer main vocabulary files + the additional tokens files
+                for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items():
+                    if os.path.isdir(pretrained_model_name_or_path):
+                        full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
+                        if not os.path.exists(full_file_name):
+                            logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
+                            full_file_name = None
+                    else:
+                        full_file_name = hf_bucket_url(pretrained_model_name_or_path, postfix=file_name)
+                    vocab_files[file_id] = full_file_name
 
             # Get files from url, cache, or disk depending on the case
             try:
@@ -414,6 +399,18 @@ class PreTrainedTokenizer(object):
 
             raise EnvironmentError(msg)
 
+        if all(full_file_name is None for full_file_name in resolved_vocab_files.values()):
+            raise EnvironmentError(
+                "Model name '{}' was not found in tokenizers model name list ({}). "
+                "We assumed '{}' was a path, a model identifier, or url to a directory containing vocabulary files "
+                "named {} but couldn't find such vocabulary files at this path or url.".format(
+                    pretrained_model_name_or_path,
+                    ", ".join(s3_models),
+                    pretrained_model_name_or_path,
+                    list(cls.vocab_files_names.values()),
+                )
+            )
+
         for file_id, file_path in vocab_files.items():
             if file_path == resolved_vocab_files[file_id]:
                 logger.info("loading file {}".format(file_path))
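
One consequence of the reworked lookup above: pointing `from_pretrained` at a single vocabulary file or url keeps working only for tokenizers with a single entry in `vocab_files_names`, now with a deprecation warning, while multi-file tokenizers raise a `ValueError`. A short illustration, assuming a `./my_model_directory/` previously produced by `save_pretrained` as in the docstring above (not part of the diff):

    from transformers import BertTokenizer

    # Deprecated but still supported: Bert only needs vocab.txt, so a single
    # file path is mapped onto that one vocabulary file (with a warning logged).
    tok = BertTokenizer.from_pretrained("./my_model_directory/vocab.txt")

    # Preferred: point at the directory (or a model identifier) so added_tokens.json,
    # special_tokens_map.json and tokenizer_config.json are also picked up.
    tok = BertTokenizer.from_pretrained("./my_model_directory/")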
@@ -56,3 +56,17 @@ class AutoTokenizerTest(unittest.TestCase):
         tokenizer = AutoTokenizer.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
         self.assertIsInstance(tokenizer, RobertaTokenizer)
         self.assertEqual(len(tokenizer), 20)
+
+    def test_tokenizer_identifier_with_correct_config(self):
+        logging.basicConfig(level=logging.INFO)
+        for tokenizer_class in [BertTokenizer, AutoTokenizer]:
+            tokenizer = tokenizer_class.from_pretrained("wietsedv/bert-base-dutch-cased")
+            self.assertIsInstance(tokenizer, BertTokenizer)
+            self.assertEqual(tokenizer.basic_tokenizer.do_lower_case, False)
+            self.assertEqual(tokenizer.max_len, 512)
+
+    def test_tokenizer_identifier_non_existent(self):
+        logging.basicConfig(level=logging.INFO)
+        for tokenizer_class in [BertTokenizer, AutoTokenizer]:
+            with self.assertRaises(EnvironmentError):
+                _ = tokenizer_class.from_pretrained("julien-c/herlolip-not-exists")