Unverified Commit adb70eda authored by Sylvain Gugger, committed by GitHub

AutoTokenizer: infer the class from the tokenizer config if possible (#12208)



* AutoTokenizer: infer the class from the tokenizer config if possible

* Add tests

* Update src/transformers/models/auto/tokenization_auto.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 0daadc19
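
In short, `save_pretrained` now records the tokenizer class in tokenizer_config.json, and `AutoTokenizer.from_pretrained` reads it back before falling back to the model config. A minimal sketch of the resulting round trip (the local path is illustrative):

from transformers import AutoTokenizer

# Saving now writes "tokenizer_class" into tokenizer_config.json ...
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
tokenizer.save_pretrained("./my-tokenizer")

# ... so reloading can infer the tokenizer class without needing a model config.
reloaded = AutoTokenizer.from_pretrained("./my-tokenizer")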
src/transformers/models/auto/tokenization_auto.py
@@ -14,12 +14,21 @@
# limitations under the License.
""" Auto Tokenizer class. """

import json
import os
from collections import OrderedDict
from typing import Dict, Optional, Union

from ... import GPTNeoConfig
from ...configuration_utils import PretrainedConfig
from ...file_utils import (
    cached_path,
    hf_bucket_url,
    is_offline_mode,
    is_sentencepiece_available,
    is_tokenizers_available,
)
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
from ...utils import logging
from ..bart.tokenization_bart import BartTokenizer
from ..bert.tokenization_bert import BertTokenizer

@@ -323,6 +332,105 @@ def tokenizer_class_from_name(class_name: str):
            return c

def get_tokenizer_config(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    **kwargs,
):
"""
Loads the tokenizer configuration from a pretrained model tokenizer configuration.
Args:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
This can be either:
- a string, the `model id` of a pretrained model configuration hosted inside a model repo on
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing a configuration file saved using the
:func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g., ``./my_model_directory/``.
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
cache should not be used.
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force to (re-)download the configuration files and override the cached versions if they
exist.
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (:obj:`str` or `bool`, `optional`):
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, will only try to load the tokenizer configuration from local files.
.. note::
Passing :obj:`use_auth_token=True` is required when you want to use a private model.
Returns:
:obj:`Dict`: The configuration of the tokenizer.
Examples::
# Download configuration from huggingface.co and cache.
tokenizer_config = get_tokenizer_config("bert-base-uncased")
# This model does not have a tokenizer config so the result will be an empty dict.
tokenizer_config = get_tokenizer_config("xlm-roberta-base")
# Save a pretrained tokenizer locally and you can reload its config
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
tokenizer.save_pretrained("tokenizer-test")
tokenizer_config = get_tokenizer_config("tokenizer-test")
"""
    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    if os.path.isdir(pretrained_model_name_or_path):
        config_file = os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE)
    else:
        config_file = hf_bucket_url(
            pretrained_model_name_or_path, filename=TOKENIZER_CONFIG_FILE, revision=revision, mirror=None
        )

    try:
        # Load from URL or cache if already cached
        resolved_config_file = cached_path(
            config_file,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
        )

    except EnvironmentError:
        logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
        return {}

    with open(resolved_config_file, encoding="utf-8") as reader:
        return json.load(reader)
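
As a usage sketch of the helper above (`bert-base-cased` as in the docstring examples), note that a missing tokenizer config is not an error; the function simply returns an empty dict:

from transformers.models.auto.tokenization_auto import get_tokenizer_config

config = get_tokenizer_config("bert-base-cased", revision="main")
# "tokenizer_class" may be absent when the repo's config does not record it.
tokenizer_class = config.get("tokenizer_class")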

class AutoTokenizer:
    r"""
    This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when
@@ -408,18 +516,27 @@ class AutoTokenizer:
""" """
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
kwargs["_from_auto"] = True kwargs["_from_auto"] = True
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
use_fast = kwargs.pop("use_fast", True) use_fast = kwargs.pop("use_fast", True)
if config.tokenizer_class is not None: # First, let's try to use the tokenizer_config file to get the tokenizer class.
tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
config_tokenizer_class = tokenizer_config.get("tokenizer_class")
# If that did not work, let's try to use the config.
if config_tokenizer_class is None:
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
config_tokenizer_class = config.tokenizer_class
# If we have the tokenizer class from the tokenizer config or the model config we're good!
if config_tokenizer_class is not None:
tokenizer_class = None tokenizer_class = None
if use_fast and not config.tokenizer_class.endswith("Fast"): if use_fast and not config_tokenizer_class.endswith("Fast"):
tokenizer_class_candidate = f"{config.tokenizer_class}Fast" tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate) tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
if tokenizer_class is None: if tokenizer_class is None:
tokenizer_class_candidate = config.tokenizer_class tokenizer_class_candidate = config_tokenizer_class
tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate) tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
if tokenizer_class is None: if tokenizer_class is None:
...@@ -428,6 +545,7 @@ class AutoTokenizer: ...@@ -428,6 +545,7 @@ class AutoTokenizer:
) )
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
# Otherwise we have to be creative.
# if model is an encoder decoder, the encoder tokenizer class is used by default # if model is an encoder decoder, the encoder tokenizer class is used by default
if isinstance(config, EncoderDecoderConfig): if isinstance(config, EncoderDecoderConfig):
if type(config.decoder) is not type(config.encoder): # noqa: E721 if type(config.decoder) is not type(config.encoder): # noqa: E721
......
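
To make the control flow above easier to follow, here is a standalone sketch of the class-name resolution; resolve_tokenizer_class is a hypothetical helper written for illustration, not part of this diff. The method itself tries tokenizer_config.json first, then config.tokenizer_class, and only falls through to the TOKENIZER_MAPPING lookup when neither records a class.

from transformers.models.auto.tokenization_auto import tokenizer_class_from_name

def resolve_tokenizer_class(config_tokenizer_class: str, use_fast: bool = True):
    # Prefer the fast variant when one is requested and the recorded name is the slow class ...
    if use_fast and not config_tokenizer_class.endswith("Fast"):
        tokenizer_class = tokenizer_class_from_name(f"{config_tokenizer_class}Fast")
        if tokenizer_class is not None:
            return tokenizer_class
    # ... otherwise look up the name exactly as recorded (may be None if not importable).
    return tokenizer_class_from_name(config_tokenizer_class)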
src/transformers/tokenization_utils_base.py
@@ -1745,6 +1745,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
        if tokenizer_config_file is not None:
            with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
                init_kwargs = json.load(tokenizer_config_handle)
            init_kwargs.pop("tokenizer_class", None)
            saved_init_inputs = init_kwargs.pop("init_inputs", ())
            if not init_inputs:
                init_inputs = saved_init_inputs
@@ -1920,6 +1921,14 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
        # add_type_field=True to allow dicts in the kwargs / differentiate from AddedToken serialization
        tokenizer_config = convert_added_tokens(tokenizer_config, add_type_field=True)

        # Add tokenizer class to the tokenizer config to be able to reload it with from_pretrained
        tokenizer_class = self.__class__.__name__
        # Remove the Fast at the end unless we have a special `PreTrainedTokenizerFast`
        if tokenizer_class.endswith("Fast") and tokenizer_class != "PreTrainedTokenizerFast":
            tokenizer_class = tokenizer_class[:-4]
        tokenizer_config["tokenizer_class"] = tokenizer_class

        with open(tokenizer_config_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(tokenizer_config, ensure_ascii=False))
        logger.info(f"tokenizer config file saved in {tokenizer_config_file}")
...
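
The suffix handling above means a fast tokenizer is recorded under its slow class name, so the saved config can be reloaded as either variant. A minimal standalone sketch of that normalization (hypothetical helper, for illustration only):

def normalized_class_name(cls_name: str) -> str:
    # Mirrors the save_pretrained logic above: "BertTokenizerFast" -> "BertTokenizer".
    # The generic PreTrainedTokenizerFast keeps its name since it has no slow twin.
    if cls_name.endswith("Fast") and cls_name != "PreTrainedTokenizerFast":
        return cls_name[:-4]
    return cls_name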
tests/test_tokenization_auto.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import unittest

from transformers import (
@@ -29,7 +29,7 @@ from transformers import (
    RobertaTokenizerFast,
)
from transformers.models.auto.configuration_auto import AutoConfig
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING, get_tokenizer_config
from transformers.models.roberta.configuration_roberta import RobertaConfig
from transformers.testing_utils import (
    DUMMY_DIFF_TOKENIZER_IDENTIFIER,
@@ -129,3 +129,34 @@ class AutoTokenizerTest(unittest.TestCase):
        self.assertEqual(tokenizer.vocab_size, 30000)
        self.assertEqual(tokenizer.unk_token, "[UNK]")
        self.assertEqual(tokenizer.padding_side, "right")
    def test_auto_tokenizer_from_local_folder(self):
        tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
        self.assertIsInstance(tokenizer, (BertTokenizer, BertTokenizerFast))
        with tempfile.TemporaryDirectory() as tmp_dir:
            tokenizer.save_pretrained(tmp_dir)
            tokenizer2 = AutoTokenizer.from_pretrained(tmp_dir)

        self.assertIsInstance(tokenizer2, tokenizer.__class__)
        self.assertEqual(tokenizer2.vocab_size, 12)

    def test_get_tokenizer_config(self):
        # Check we can load the tokenizer config of an online model.
        config = get_tokenizer_config("bert-base-cased")
        # If we ever update the bert-base-cased tokenizer config, this dict here will need to be updated.
        self.assertEqual(config, {"do_lower_case": False})

        # This model does not have a tokenizer_config so we get back an empty dict.
        config = get_tokenizer_config(SMALL_MODEL_IDENTIFIER)
        self.assertDictEqual(config, {})

        # A tokenizer saved with `save_pretrained` always creates a tokenizer config.
        tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
        with tempfile.TemporaryDirectory() as tmp_dir:
            tokenizer.save_pretrained(tmp_dir)
            config = get_tokenizer_config(tmp_dir)

        # Check the class of the tokenizer was properly saved (note that it always saves the slow class).
        self.assertEqual(config["tokenizer_class"], "BertTokenizer")
        # Check other keys just to make sure the config was properly saved / reloaded.
        self.assertEqual(config["name_or_path"], SMALL_MODEL_IDENTIFIER)