Unverified Commit 6cb0a6f0 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Partial local tokenizer load (#9807)



* Allow partial loading of a cached tokenizer

* Warning > Info

* Update src/transformers/tokenization_utils_base.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Raise error if not local_files_only
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 25fcb5c1
......@@ -1239,7 +1239,7 @@ def get_from_cache(
# the models might've been found if local_files_only=False
# Notify the user about that
if local_files_only:
raise ValueError(
raise FileNotFoundError(
"Cannot find the requested files in the cached path and outgoing traffic has been"
" disabled. To enable model look-ups and downloads online, set 'local_files_only'"
" to False."
......
......@@ -1730,20 +1730,28 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
# Get files from url, cache, or disk depending on the case
resolved_vocab_files = {}
unresolved_files = []
for file_id, file_path in vocab_files.items():
if file_path is None:
resolved_vocab_files[file_id] = None
else:
try:
resolved_vocab_files[file_id] = cached_path(
file_path,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
)
try:
resolved_vocab_files[file_id] = cached_path(
file_path,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
)
except FileNotFoundError as error:
if local_files_only:
unresolved_files.append(file_id)
else:
raise error
except requests.exceptions.HTTPError as err:
if "404 Client Error" in str(err):
logger.debug(err)
......@@ -1751,6 +1759,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
else:
raise err
if len(unresolved_files) > 0:
logger.info(
f"Can't load following files from cache: {unresolved_files} and cannot check if these "
"files are necessary for the tokenizer to operate."
)
if all(full_file_name is None for full_file_name in resolved_vocab_files.values()):
msg = (
f"Can't load tokenizer for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
......@@ -1760,6 +1774,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
raise EnvironmentError(msg)
for file_id, file_path in vocab_files.items():
if file_id not in resolved_vocab_files:
continue
if file_path == resolved_vocab_files[file_id]:
logger.info("loading file {}".format(file_path))
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment