Unverified Commit 6cb0a6f0 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Partial local tokenizer load (#9807)



* Allow partial loading of a cached tokenizer

* Warning > Info

* Update src/transformers/tokenization_utils_base.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Raise error if not local_files_only
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 25fcb5c1
...@@ -1239,7 +1239,7 @@ def get_from_cache( ...@@ -1239,7 +1239,7 @@ def get_from_cache(
# the models might've been found if local_files_only=False # the models might've been found if local_files_only=False
# Notify the user about that # Notify the user about that
if local_files_only: if local_files_only:
raise ValueError( raise FileNotFoundError(
"Cannot find the requested files in the cached path and outgoing traffic has been" "Cannot find the requested files in the cached path and outgoing traffic has been"
" disabled. To enable model look-ups and downloads online, set 'local_files_only'" " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
" to False." " to False."
......
...@@ -1730,10 +1730,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): ...@@ -1730,10 +1730,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
# Get files from url, cache, or disk depending on the case # Get files from url, cache, or disk depending on the case
resolved_vocab_files = {} resolved_vocab_files = {}
unresolved_files = []
for file_id, file_path in vocab_files.items(): for file_id, file_path in vocab_files.items():
if file_path is None: if file_path is None:
resolved_vocab_files[file_id] = None resolved_vocab_files[file_id] = None
else: else:
try:
try: try:
resolved_vocab_files[file_id] = cached_path( resolved_vocab_files[file_id] = cached_path(
file_path, file_path,
...@@ -1744,6 +1746,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): ...@@ -1744,6 +1746,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
) )
except FileNotFoundError as error:
if local_files_only:
unresolved_files.append(file_id)
else:
raise error
except requests.exceptions.HTTPError as err: except requests.exceptions.HTTPError as err:
if "404 Client Error" in str(err): if "404 Client Error" in str(err):
logger.debug(err) logger.debug(err)
...@@ -1751,6 +1759,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): ...@@ -1751,6 +1759,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
else: else:
raise err raise err
if len(unresolved_files) > 0:
logger.info(
f"Can't load following files from cache: {unresolved_files} and cannot check if these "
"files are necessary for the tokenizer to operate."
)
if all(full_file_name is None for full_file_name in resolved_vocab_files.values()): if all(full_file_name is None for full_file_name in resolved_vocab_files.values()):
msg = ( msg = (
f"Can't load tokenizer for '{pretrained_model_name_or_path}'. Make sure that:\n\n" f"Can't load tokenizer for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
...@@ -1760,6 +1774,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): ...@@ -1760,6 +1774,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
raise EnvironmentError(msg) raise EnvironmentError(msg)
for file_id, file_path in vocab_files.items(): for file_id, file_path in vocab_files.items():
if file_id not in resolved_vocab_files:
continue
if file_path == resolved_vocab_files[file_id]: if file_path == resolved_vocab_files[file_id]:
logger.info("loading file {}".format(file_path)) logger.info("loading file {}".format(file_path))
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment