Commit 23a2cea8 authored by Julien Chaumond

Tokenizer.from_pretrained: fetch all possible files remotely

parent 99f9243d
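
The hunks below touch the config, file-download, and tokenizer loading paths so that a user-namespaced model identifier resolves every tokenizer file remotely, not only the vocabulary file. As a rough usage sketch of the intended behaviour (assuming this revision is installed and the identifier used in the new test below is reachable):

from transformers import AutoTokenizer

# A user-namespaced identifier now resolves all tokenizer files
# (vocabulary, added tokens, special tokens map, tokenizer config)
# against the bucket, instead of only shortcut names or local directories.
tokenizer = AutoTokenizer.from_pretrained("wietsedv/bert-base-dutch-cased")
print(tokenizer.tokenize("Dit is een voorbeeld."))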
@@ -200,6 +200,8 @@ class PretrainedConfig(object):
resume_download=resume_download,
)
# Load config dict
if resolved_config_file is None:
raise EnvironmentError
config_dict = cls._dict_from_json_file(resolved_config_file)
except EnvironmentError:
@@ -210,7 +212,7 @@ class PretrainedConfig(object):
else:
msg = (
"Model name '{}' was not found in model name list. "
"We assumed '{}' was a path or url to a configuration file named {} or "
"We assumed '{}' was a path, a model identifier, or url to a configuration file named {} or "
"a directory containing such a file but couldn't find any such file at this path or url.".format(
pretrained_model_name_or_path, config_file, CONFIG_NAME,
)
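
With the added None check, a failed fetch now surfaces as an EnvironmentError carrying the message above, rather than an unrelated error from passing None further down. A hedged sketch of what a caller sees, reusing the non-existent identifier from the new negative test at the bottom of this commit:

from transformers import BertConfig

try:
    # this identifier deliberately does not exist
    BertConfig.from_pretrained("julien-c/herlolip-not-exists")
except EnvironmentError as err:
    print("config could not be resolved:", err)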
@@ -14,6 +14,7 @@ import tempfile
from contextlib import contextmanager
from functools import partial, wraps
from hashlib import sha256
from typing import Optional
from urllib.parse import urlparse
import boto3
@@ -122,7 +123,7 @@ def is_remote_url(url_or_filename):
return parsed.scheme in ("http", "https", "s3")
def hf_bucket_url(identifier, postfix=None, cdn=False):
def hf_bucket_url(identifier, postfix=None, cdn=False) -> str:
endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
if postfix is None:
return "/".join((endpoint, identifier))
@@ -182,7 +183,7 @@ def filename_to_url(filename, cache_dir=None):
def cached_path(
url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None
):
) -> Optional[str]:
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
@@ -193,6 +194,10 @@ def cached_path(
force_download: if True, re-download the file even if it's already cached in the cache dir.
resume_download: if True, resume the download if an incompletely received file is found.
user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
Return:
None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
Local path (string) otherwise
"""
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
@@ -306,10 +311,14 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
def get_from_cache(
url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None
):
) -> Optional[str]:
"""
Given a URL, look for the corresponding dataset in the local cache.
Given a URL, look for the corresponding file in the local cache.
If it's not there, download it. Then return the path to the cached file.
Return:
None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
Local path (string) otherwise
"""
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
@@ -336,16 +345,25 @@ def get_from_cache(
# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)
# If we don't have a connection (etag is None) and can't identify the file
# etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
# try to get the last downloaded one
if not os.path.exists(cache_path) and etag is None:
matching_files = [
file
for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
if not file.endswith(".json") and not file.endswith(".lock")
]
if matching_files:
cache_path = os.path.join(cache_dir, matching_files[-1])
if etag is None:
if os.path.exists(cache_path):
return cache_path
else:
matching_files = [
file
for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
if not file.endswith(".json") and not file.endswith(".lock")
]
if len(matching_files) > 0:
return os.path.join(cache_dir, matching_files[-1])
else:
return None
# From now on, etag is not None.
if os.path.exists(cache_path) and not force_download:
return cache_path
# Prevent parallel downloads of the same file with a lock.
lock_path = cache_path + ".lock"
@@ -368,29 +386,26 @@ def get_from_cache(
temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
resume_size = 0
if etag is not None and (not os.path.exists(cache_path) or force_download):
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with temp_file_manager() as temp_file:
logger.info(
"%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name
)
# GET file object
if url.startswith("s3://"):
if resume_download:
logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
s3_get(url, temp_file, proxies=proxies)
else:
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
logger.info("storing %s in cache at %s", url, cache_path)
os.rename(temp_file.name, cache_path)
logger.info("creating metadata file for %s", cache_path)
meta = {"url": url, "etag": etag}
meta_path = cache_path + ".json"
with open(meta_path, "w") as meta_file:
json.dump(meta, meta_file)
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with temp_file_manager() as temp_file:
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
# GET file object
if url.startswith("s3://"):
if resume_download:
logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
s3_get(url, temp_file, proxies=proxies)
else:
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
logger.info("storing %s in cache at %s", url, cache_path)
os.rename(temp_file.name, cache_path)
logger.info("creating metadata file for %s", cache_path)
meta = {"url": url, "etag": etag}
meta_path = cache_path + ".json"
with open(meta_path, "w") as meta_file:
json.dump(meta, meta_file)
return cache_path
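
The restructuring above makes get_from_cache's offline behaviour explicit: when no etag can be fetched, the exact cache entry is preferred, then any previously downloaded variant, otherwise None is returned. A condensed sketch of just that branch, with a hypothetical helper name rather than the function's real signature:

import fnmatch
import os
from typing import Optional

def offline_fallback(cache_dir: str, cache_path: str, filename: str) -> Optional[str]:
    # Prefer the exact cached file for this url.
    if os.path.exists(cache_path):
        return cache_path
    # Otherwise take the most recent matching download, skipping metadata and lock files.
    matching_files = [
        f
        for f in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
        if not f.endswith(".json") and not f.endswith(".lock")
    ]
    if matching_files:
        return os.path.join(cache_dir, matching_files[-1])
    return None  # non-recoverable: no connection/etag and nothing on disk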
@@ -264,7 +264,7 @@ class PreTrainedTokenizer(object):
- a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
- a string with the `identifier name` of a predefined tokenizer that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
- (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.
- (not applicable to all derived classes, deprecated) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.
cache_dir: (`optional`) string:
Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.
@@ -331,57 +331,42 @@ class PreTrainedTokenizer(object):
# Get the vocabulary from local files
logger.info(
"Model name '{}' not found in model shortcut name list ({}). "
"Assuming '{}' is a path or url to a directory containing tokenizer files.".format(
"Assuming '{}' is a path, a model identifier, or url to a directory containing tokenizer files.".format(
pretrained_model_name_or_path, ", ".join(s3_models), pretrained_model_name_or_path
)
)
# Look for the tokenizer main vocabulary files
for file_id, file_name in cls.vocab_files_names.items():
if os.path.isdir(pretrained_model_name_or_path):
# If a directory is provided we look for the standard filenames
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
if not os.path.exists(full_file_name):
logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
full_file_name = None
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
# If a path to a file is provided we use it (will only work for non-BPE tokenizer using a single vocabulary file)
full_file_name = pretrained_model_name_or_path
else:
full_file_name = hf_bucket_url(pretrained_model_name_or_path, postfix=file_name)
vocab_files[file_id] = full_file_name
# Look for the additional tokens files
additional_files_names = {
"added_tokens_file": ADDED_TOKENS_FILE,
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
}
# If a path to a file was provided, get the parent directory
saved_directory = pretrained_model_name_or_path
if os.path.exists(saved_directory) and not os.path.isdir(saved_directory):
saved_directory = os.path.dirname(saved_directory)
for file_id, file_name in additional_files_names.items():
full_file_name = os.path.join(saved_directory, file_name)
if not os.path.exists(full_file_name):
logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
full_file_name = None
vocab_files[file_id] = full_file_name
if all(full_file_name is None for full_file_name in vocab_files.values()):
raise EnvironmentError(
"Model name '{}' was not found in tokenizers model name list ({}). "
"We assumed '{}' was a path or url to a directory containing vocabulary files "
"named {} but couldn't find such vocabulary files at this path or url.".format(
pretrained_model_name_or_path,
", ".join(s3_models),
pretrained_model_name_or_path,
list(cls.vocab_files_names.values()),
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
if len(cls.vocab_files_names) > 1:
raise ValueError(
"Calling {}.from_pretrained() with the path to a single file or url is not supported."
"Use a model identifier or the path to a directory instead.".format(cls.__name__)
)
logger.warning(
"Calling {}.from_pretrained() with the path to a single file or url is deprecated".format(
cls.__name__
)
)
file_id = list(cls.vocab_files_names.keys())[0]
vocab_files[file_id] = pretrained_model_name_or_path
else:
# At this point pretrained_model_name_or_path is either a directory or a model identifier name
additional_files_names = {
"added_tokens_file": ADDED_TOKENS_FILE,
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
}
# Look for the tokenizer main vocabulary files + the additional tokens files
for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items():
if os.path.isdir(pretrained_model_name_or_path):
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
if not os.path.exists(full_file_name):
logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
full_file_name = None
else:
full_file_name = hf_bucket_url(pretrained_model_name_or_path, postfix=file_name)
vocab_files[file_id] = full_file_name
# Get files from url, cache, or disk depending on the case
try:
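
In short, the loader now treats a single file or url as a deprecated special case, and for a directory or model identifier it resolves the main vocabulary files plus the three auxiliary files either locally or via hf_bucket_url. A standalone sketch of that resolution step; the file-name dictionaries are illustrative stand-ins for the class attributes and module constants used above:

import os

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}  # illustrative; depends on the tokenizer class
ADDITIONAL_FILES_NAMES = {
    "added_tokens_file": "added_tokens.json",
    "special_tokens_map_file": "special_tokens_map.json",
    "tokenizer_config_file": "tokenizer_config.json",
}

def resolve_vocab_files(name_or_path, hf_bucket_url):
    vocab_files = {}
    for file_id, file_name in {**VOCAB_FILES_NAMES, **ADDITIONAL_FILES_NAMES}.items():
        if os.path.isdir(name_or_path):
            # Local directory: use the file if present, otherwise skip it.
            full_file_name = os.path.join(name_or_path, file_name)
            if not os.path.exists(full_file_name):
                full_file_name = None
        else:
            # Model identifier: build a remote url for every expected file.
            full_file_name = hf_bucket_url(name_or_path, postfix=file_name)
        vocab_files[file_id] = full_file_name
    return vocab_files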
@@ -414,6 +399,18 @@ class PreTrainedTokenizer(object):
raise EnvironmentError(msg)
if all(full_file_name is None for full_file_name in resolved_vocab_files.values()):
raise EnvironmentError(
"Model name '{}' was not found in tokenizers model name list ({}). "
"We assumed '{}' was a path, a model identifier, or url to a directory containing vocabulary files "
"named {} but couldn't find such vocabulary files at this path or url.".format(
pretrained_model_name_or_path,
", ".join(s3_models),
pretrained_model_name_or_path,
list(cls.vocab_files_names.values()),
)
)
for file_id, file_path in vocab_files.items():
if file_path == resolved_vocab_files[file_id]:
logger.info("loading file {}".format(file_path))
@@ -56,3 +56,17 @@ class AutoTokenizerTest(unittest.TestCase):
tokenizer = AutoTokenizer.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
self.assertIsInstance(tokenizer, RobertaTokenizer)
self.assertEqual(len(tokenizer), 20)
def test_tokenizer_identifier_with_correct_config(self):
logging.basicConfig(level=logging.INFO)
for tokenizer_class in [BertTokenizer, AutoTokenizer]:
tokenizer = tokenizer_class.from_pretrained("wietsedv/bert-base-dutch-cased")
self.assertIsInstance(tokenizer, BertTokenizer)
self.assertEqual(tokenizer.basic_tokenizer.do_lower_case, False)
self.assertEqual(tokenizer.max_len, 512)
def test_tokenizer_identifier_non_existent(self):
logging.basicConfig(level=logging.INFO)
for tokenizer_class in [BertTokenizer, AutoTokenizer]:
with self.assertRaises(EnvironmentError):
_ = tokenizer_class.from_pretrained("julien-c/herlolip-not-exists")