Unverified Commit 89693e17 authored by Sylvain Gugger, committed by GitHub

Remove special treatment for custom vocab files (#10637)



* Remove special path for custom vocab files

* Update src/transformers/tokenization_utils_base.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Expand error message
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 6d9e11a1
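With this change, from_pretrained() no longer consults the hard-coded shortcut list (s3_models) that mapped known model names to vocabulary URLs via pretrained_vocab_files_map; every input now goes through the same resolution path. A minimal sketch of the resulting branch order, with hypothetical string return values standing in for the vocab_files dict the real code (in the diff below) actually builds; the real code also treats remote urls like single files via is_remote_url:

import os
import warnings

def resolve_tokenizer_source(name_or_path: str) -> str:
    # Sketch only: mirrors the branch order introduced by this commit.
    if os.path.isfile(name_or_path):
        # A single vocab file still loads, but is deprecated and slated for removal in v5.
        warnings.warn(
            "Passing a single file is deprecated; use a directory or model identifier.",
            FutureWarning,
        )
        return "single_file"
    elif os.path.isdir(name_or_path):
        # A directory is scanned for the standard tokenizer file names.
        return "directory"
    else:
        # Everything else is assumed to be a model identifier on the hub.
        return "model_identifier"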
src/transformers/tokenization_utils_base.py
@@ -1601,69 +1601,51 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
             logger.info("Offline mode: forcing local_files_only=True")
             local_files_only = True
 
-        s3_models = list(cls.max_model_input_sizes.keys())
         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
         vocab_files = {}
         init_configuration = {}
-        if pretrained_model_name_or_path in s3_models:
-            # Get the vocabulary from AWS S3 bucket
-            for file_id, map_list in cls.pretrained_vocab_files_map.items():
-                vocab_files[file_id] = map_list[pretrained_model_name_or_path]
-            if (
-                cls.pretrained_init_configuration
-                and pretrained_model_name_or_path in cls.pretrained_init_configuration
-            ):
-                init_configuration = cls.pretrained_init_configuration[pretrained_model_name_or_path].copy()
-        else:
-            # Get the vocabulary from local files
-            logger.info(
-                "Model name '{}' not found in model shortcut name list ({}). "
-                "Assuming '{}' is a path, a model identifier, or url to a directory containing tokenizer files.".format(
-                    pretrained_model_name_or_path, ", ".join(s3_models), pretrained_model_name_or_path
-                )
-            )
-            if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
-                if len(cls.vocab_files_names) > 1:
-                    raise ValueError(
-                        "Calling {}.from_pretrained() with the path to a single file or url is not supported."
-                        "Use a model identifier or the path to a directory instead.".format(cls.__name__)
-                    )
-                logger.warning(
-                    "Calling {}.from_pretrained() with the path to a single file or url is deprecated".format(
-                        cls.__name__
-                    )
-                )
-                file_id = list(cls.vocab_files_names.keys())[0]
-                vocab_files[file_id] = pretrained_model_name_or_path
-            else:
-                # At this point pretrained_model_name_or_path is either a directory or a model identifier name
-                additional_files_names = {
-                    "added_tokens_file": ADDED_TOKENS_FILE,
-                    "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
-                    "tokenizer_config_file": TOKENIZER_CONFIG_FILE,
-                    "tokenizer_file": FULL_TOKENIZER_FILE,
-                }
-                # Look for the tokenizer files
-                for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items():
-                    if os.path.isdir(pretrained_model_name_or_path):
-                        if subfolder is not None:
-                            full_file_name = os.path.join(pretrained_model_name_or_path, subfolder, file_name)
-                        else:
-                            full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
-                        if not os.path.exists(full_file_name):
-                            logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
-                            full_file_name = None
-                    else:
-                        full_file_name = hf_bucket_url(
-                            pretrained_model_name_or_path,
-                            filename=file_name,
-                            subfolder=subfolder,
-                            revision=revision,
-                            mirror=None,
-                        )
-                    vocab_files[file_id] = full_file_name
+        if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
+            if len(cls.vocab_files_names) > 1:
+                raise ValueError(
+                    f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not "
+                    "supported for this tokenizer. Use a model identifier or the path to a directory instead."
+                )
+            warnings.warn(
+                f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is deprecated and "
+                "won't be possible anymore in v5. Use a model identifier or the path to a directory instead.",
+                FutureWarning,
+            )
+            file_id = list(cls.vocab_files_names.keys())[0]
+            vocab_files[file_id] = pretrained_model_name_or_path
+        else:
+            # At this point pretrained_model_name_or_path is either a directory or a model identifier name
+            additional_files_names = {
+                "added_tokens_file": ADDED_TOKENS_FILE,
+                "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
+                "tokenizer_config_file": TOKENIZER_CONFIG_FILE,
+                "tokenizer_file": FULL_TOKENIZER_FILE,
+            }
+            # Look for the tokenizer files
+            for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items():
+                if os.path.isdir(pretrained_model_name_or_path):
+                    if subfolder is not None:
+                        full_file_name = os.path.join(pretrained_model_name_or_path, subfolder, file_name)
+                    else:
+                        full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
+                    if not os.path.exists(full_file_name):
+                        logger.info(f"Didn't find file {full_file_name}. We won't load it.")
+                        full_file_name = None
+                else:
+                    full_file_name = hf_bucket_url(
+                        pretrained_model_name_or_path,
+                        filename=file_name,
+                        subfolder=subfolder,
+                        revision=revision,
+                        mirror=None,
+                    )
+                vocab_files[file_id] = full_file_name
 
         # Get files from url, cache, or disk depending on the case
         resolved_vocab_files = {}
@@ -1673,21 +1655,21 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
                 resolved_vocab_files[file_id] = None
             else:
                 try:
-                    try:
-                        resolved_vocab_files[file_id] = cached_path(
-                            file_path,
-                            cache_dir=cache_dir,
-                            force_download=force_download,
-                            proxies=proxies,
-                            resume_download=resume_download,
-                            local_files_only=local_files_only,
-                            use_auth_token=use_auth_token,
-                        )
-                    except FileNotFoundError as error:
-                        if local_files_only:
-                            unresolved_files.append(file_id)
-                        else:
-                            raise error
+                    resolved_vocab_files[file_id] = cached_path(
+                        file_path,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        proxies=proxies,
+                        resume_download=resume_download,
+                        local_files_only=local_files_only,
+                        use_auth_token=use_auth_token,
+                    )
+
+                except FileNotFoundError as error:
+                    if local_files_only:
+                        unresolved_files.append(file_id)
+                    else:
+                        raise error
 
                 except requests.exceptions.HTTPError as err:
                     if "404 Client Error" in str(err):
@@ -1715,9 +1697,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
                 continue
 
             if file_path == resolved_vocab_files[file_id]:
-                logger.info("loading file {}".format(file_path))
+                logger.info(f"loading file {file_path}")
             else:
-                logger.info("loading file {} from cache at {}".format(file_path, resolved_vocab_files[file_id]))
+                logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}")
 
         return cls._from_pretrained(
             resolved_vocab_files, pretrained_model_name_or_path, init_configuration, *init_inputs, **kwargs
...
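The second hunk is a pure control-flow simplification: a single try with two except clauses replaces the previous nested try blocks, since the inner try existed only to scope the FileNotFoundError handler. A standalone sketch of the flattened flow under that reading, with fetch as a hypothetical stand-in for cached_path and its arguments:

import requests

def resolve_file(file_id, file_path, fetch, resolved, unresolved, local_files_only=False):
    # One try with two except clauses; the behavior matches the old nested version.
    try:
        resolved[file_id] = fetch(file_path)
    except FileNotFoundError:
        if local_files_only:
            # Offline mode: remember what is missing and keep resolving the rest.
            unresolved.append(file_id)
        else:
            raise
    except requests.exceptions.HTTPError as err:
        if "404 Client Error" in str(err):
            # An optional tokenizer file absent upstream is recorded as None, not fatal.
            resolved[file_id] = None
        else:
            raise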