# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Tokenizer class."""

import importlib
import json
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union

from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import get_file_from_repo, is_sentencepiece_available, is_tokenizers_available, logging
from ..encoder_decoder import EncoderDecoderConfig
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
    CONFIG_MAPPING_NAMES,
    AutoConfig,
    config_class_to_model_type,
    model_type_to_module_name,
    replace_list_option_in_docstrings,
)


logger = logging.get_logger(__name__)


if TYPE_CHECKING:
    # This significantly improves completion suggestion performance when
    # the transformers package is used with Microsoft's Pylance language server.
    TOKENIZER_MAPPING_NAMES: OrderedDict[str, Tuple[Optional[str], Optional[str]]] = OrderedDict()
else:
    TOKENIZER_MAPPING_NAMES = OrderedDict(
        [
            ("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
            ("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
            ("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
            ("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
            ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
            (
                "t5",
                (
                    "T5Tokenizer" if is_sentencepiece_available() else None,
                    "T5TokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            (
                "mt5",
                (
                    "MT5Tokenizer" if is_sentencepiece_available() else None,
                    "MT5TokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
            ("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
            (
                "albert",
                (
                    "AlbertTokenizer" if is_sentencepiece_available() else None,
                    "AlbertTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            (
                "camembert",
                (
                    "CamembertTokenizer" if is_sentencepiece_available() else None,
                    "CamembertTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            (
                "pegasus",
                (
                    "PegasusTokenizer" if is_sentencepiece_available() else None,
                    "PegasusTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            (
                "mbart",
                (
                    "MBartTokenizer" if is_sentencepiece_available() else None,
                    "MBartTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            (
                "xlm-roberta",
                (
                    "XLMRobertaTokenizer" if is_sentencepiece_available() else None,
                    "XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
            ("blenderbot-small", ("BlenderbotSmallTokenizer", None)),
            ("blenderbot", ("BlenderbotTokenizer", "BlenderbotTokenizerFast")),
            ("tapex", ("TapexTokenizer", None)),
            ("bart", ("BartTokenizer", "BartTokenizerFast")),
            ("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
            ("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
            (
                "reformer",
                (
                    "ReformerTokenizer" if is_sentencepiece_available() else None,
                    "ReformerTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
            ("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
            ("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
            ("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
            ("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)),
            ("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
            (
                "dpr",
                (
                    "DPRQuestionEncoderTokenizer",
                    "DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            (
                "squeezebert",
                ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None),
            ),
            ("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
            ("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)),
            ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
            ("transfo-xl", ("TransfoXLTokenizer", None)),
            (
                "xlnet",
                (
                    "XLNetTokenizer" if is_sentencepiece_available() else None,
                    "XLNetTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("flaubert", ("FlaubertTokenizer", None)),
            ("xlm", ("XLMTokenizer", None)),
            ("ctrl", ("CTRLTokenizer", None)),
            ("fsmt", ("FSMTTokenizer", None)),
            ("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)),
            ("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
            ("deberta-v2", ("DebertaV2Tokenizer" if is_sentencepiece_available() else None, None)),
            ("rag", ("RagTokenizer", None)),
            ("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)),
            ("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
            ("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
            ("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
            ("prophetnet", ("ProphetNetTokenizer", None)),
            ("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
            ("tapas", ("TapasTokenizer", None)),
            ("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
            ("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
            (
                "big_bird",
                (
                    "BigBirdTokenizer" if is_sentencepiece_available() else None,
                    "BigBirdTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
            ("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
            ("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
            ("hubert", ("Wav2Vec2CTCTokenizer", None)),
            ("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
            ("luke", ("LukeTokenizer", None)),
            ("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
            ("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
            ("canine", ("CanineTokenizer", None)),
            ("bertweet", ("BertweetTokenizer", None)),
            ("bert-japanese", ("BertJapaneseTokenizer", None)),
            ("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")),
            ("byt5", ("ByT5Tokenizer", None)),
            (
                "cpm",
                (
                    "CpmTokenizer" if is_sentencepiece_available() else None,
                    "CpmTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
            ("phobert", ("PhobertTokenizer", None)),
            ("bartpho", ("BartphoTokenizer", None)),
            (
                "barthez",
                (
                    "BarthezTokenizer" if is_sentencepiece_available() else None,
                    "BarthezTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            (
                "mbart50",
                (
                    "MBart50Tokenizer" if is_sentencepiece_available() else None,
                    "MBart50TokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            (
                "rembert",
                (
                    "RemBertTokenizer" if is_sentencepiece_available() else None,
                    "RemBertTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            (
                "clip",
                (
                    "CLIPTokenizer",
                    "CLIPTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)),
            (
                "perceiver",
                (
                    "PerceiverTokenizer",
                    None,
                ),
            ),
            (
                "xglm",
                (
                    "XGLMTokenizer" if is_sentencepiece_available() else None,
                    "XGLMTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
        ]
    )

TOKENIZER_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)
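# `TOKENIZER_MAPPING` resolves lazily: a model's tokenizer module is only imported the first time
# its entry is accessed. A minimal lookup sketch (illustrative; assumes the `tokenizers` extra is installed):
#
#     from transformers import BertConfig
#     slow_cls, fast_cls = TOKENIZER_MAPPING[BertConfig]  # -> (BertTokenizer, BertTokenizerFast)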

CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()}


def tokenizer_class_from_name(class_name: str):
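    """Resolve a tokenizer class from its name, importing the module that defines it; returns `None` if unknown."""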
    if class_name == "PreTrainedTokenizerFast":
        return PreTrainedTokenizerFast

    for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items():
        if class_name in tokenizers:
            module_name = model_type_to_module_name(module_name)

            module = importlib.import_module(f".{module_name}", "transformers.models")
            return getattr(module, class_name)

    for config, tokenizers in TOKENIZER_MAPPING._extra_content.items():
        for tokenizer in tokenizers:
            if getattr(tokenizer, "__name__", None) == class_name:
                return tokenizer

    return None
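

# A minimal usage sketch for `tokenizer_class_from_name` (illustrative; assumes the `tokenizers`
# extra is installed so the fast class can be imported):
#
#     tokenizer_cls = tokenizer_class_from_name("BertTokenizerFast")
#     tokenizer = tokenizer_cls.from_pretrained("bert-base-uncased")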


def get_tokenizer_config(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    **kwargs,
):
    """
    Loads the tokenizer configuration from a pretrained model tokenizer configuration.

    Args:
279
        pretrained_model_name_or_path (`str` or `os.PathLike`):
280
281
            This can be either:

282
            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
Sylvain Gugger's avatar
Sylvain Gugger committed
283
284
              huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
              under a user or organization name, like `dbmdz/bert-base-german-cased`.
285
286
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
287

288
        cache_dir (`str` or `os.PathLike`, *optional*):
289
290
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
291
        force_download (`bool`, *optional*, defaults to `False`):
292
293
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
294
        resume_download (`bool`, *optional*, defaults to `False`):
295
            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
296
        proxies (`Dict[str, str]`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
297
298
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
299
        use_auth_token (`str` or *bool*, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
300
301
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `transformers-cli login` (stored in `~/.huggingface`).
302
        revision (`str`, *optional*, defaults to `"main"`):
303
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
304
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
305
            identifier allowed by git.
306
307
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
308

309
    <Tip>
310

311
    Passing `use_auth_token=True` is required when you want to use a private model.
312

313
    </Tip>
314
315

    Returns:
316
        `Dict`: The configuration of the tokenizer.
317

318
    Examples:
319

320
321
322
323
324
    ```python
    # Download configuration from huggingface.co and cache.
    tokenizer_config = get_tokenizer_config("bert-base-uncased")
    # This model does not have a tokenizer config so the result will be an empty dict.
    tokenizer_config = get_tokenizer_config("xlm-roberta-base")
325

326
327
    # Save a pretrained tokenizer locally and you can reload its config
    from transformers import AutoTokenizer
328

329
330
331
332
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    tokenizer.save_pretrained("tokenizer-test")
    tokenizer_config = get_tokenizer_config("tokenizer-test")
    ```"""
333
334
335
336
337
338
339
340
341
342
343
344
    resolved_config_file = get_file_from_repo(
        pretrained_model_name_or_path,
        TOKENIZER_CONFIG_FILE,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        use_auth_token=use_auth_token,
        revision=revision,
        local_files_only=local_files_only,
    )
    if resolved_config_file is None:
345
346
347
348
349
350
351
        logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
        return {}

    with open(resolved_config_file, encoding="utf-8") as reader:
        return json.load(reader)


class AutoTokenizer:
    r"""
    This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when
    created with the [`AutoTokenizer.from_pretrained`] class method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoTokenizer is designed to be instantiated "
            "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    @replace_list_option_in_docstrings(TOKENIZER_MAPPING_NAMES)
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        r"""
        Instantiate one of the tokenizer classes of the library from a pretrained model vocabulary.

        The tokenizer class to instantiate is selected based on the `model_type` property of the config object (either
        passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
        falling back to using pattern matching on `pretrained_model_name_or_path`:

        List options

        Params:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
                      user or organization name, like `dbmdz/bert-base-german-cased`.
                    - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
                      using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
                    - A path or url to a single saved vocabulary file if and only if the tokenizer only requires a
                      single vocabulary file (like Bert or XLNet), e.g.: `./my_model_directory/vocab.txt`. (Not
                      applicable to all derived classes)
            inputs (additional positional arguments, *optional*):
                Will be passed along to the Tokenizer `__init__()` method.
            config ([`PretrainedConfig`], *optional*):
                The configuration object used to determine the tokenizer class to instantiate.
            cache_dir (`str` or `os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files and override
                the cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            subfolder (`str`, *optional*):
                In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g.
                for facebook/rag-token-base), specify it here.
            use_fast (`bool`, *optional*, defaults to `True`):
                Whether or not to try to load the fast version of the tokenizer.
            tokenizer_type (`str`, *optional*):
                Tokenizer type to be loaded.
            trust_remote_code (`bool`, *optional*, defaults to `False`):
                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
                should only be set to `True` for repositories you trust and in which you have read the code, as it
                will execute code present on the Hub on your local machine.
            kwargs (additional keyword arguments, *optional*):
                Will be passed to the Tokenizer `__init__()` method. Can be used to set special tokens like
                `bos_token`, `eos_token`, `unk_token`, `sep_token`, `pad_token`, `cls_token`, `mask_token`,
                `additional_special_tokens`. See parameters in the `__init__()` for more details.

        Examples:

        ```python
        >>> from transformers import AutoTokenizer

        >>> # Download vocabulary from huggingface.co and cache.
        >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

        >>> # Download vocabulary from huggingface.co (user-uploaded) and cache.
        >>> tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-german-cased")

        >>> # If vocabulary files are in a directory (e.g. tokenizer was saved using *save_pretrained('./test/saved_model/')*)
        >>> tokenizer = AutoTokenizer.from_pretrained("./test/bert_saved_model/")
        ```"""
        config = kwargs.pop("config", None)
        kwargs["_from_auto"] = True

        use_fast = kwargs.pop("use_fast", True)
        tokenizer_type = kwargs.pop("tokenizer_type", None)
        trust_remote_code = kwargs.pop("trust_remote_code", False)

        # First, let's see whether the tokenizer_type is passed so that we can leverage it
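        # (Illustrative) e.g. `AutoTokenizer.from_pretrained("bert-base-uncased", tokenizer_type="bert")`
        # bypasses config lookup and resolves the tokenizer class directly from the mapping above.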
        if tokenizer_type is not None:
            tokenizer_class = None
            tokenizer_class_tuple = TOKENIZER_MAPPING_NAMES.get(tokenizer_type, None)

            if tokenizer_class_tuple is None:
                raise ValueError(
                    f"Passed `tokenizer_type` {tokenizer_type} does not exist. `tokenizer_type` should be one of "
                    f"{', '.join(c for c in TOKENIZER_MAPPING_NAMES.keys())}."
                )

            tokenizer_class_name, tokenizer_fast_class_name = tokenizer_class_tuple

            if use_fast and tokenizer_fast_class_name is not None:
                tokenizer_class = tokenizer_class_from_name(tokenizer_fast_class_name)

            if tokenizer_class is None:
                tokenizer_class = tokenizer_class_from_name(tokenizer_class_name)

            if tokenizer_class is None:
                raise ValueError(f"Tokenizer class {tokenizer_class_name} is not currently imported.")

            return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

        # Next, let's try to use the tokenizer_config file to get the tokenizer class.
        tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
        config_tokenizer_class = tokenizer_config.get("tokenizer_class")
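        # When the tokenizer relies on custom code, the tokenizer config carries an "auto_map" entry,
        # e.g. (illustrative): {"AutoTokenizer": ["custom_tok.MyTokenizer", "custom_tok.MyTokenizerFast"]}.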
        tokenizer_auto_map = None
        if "auto_map" in tokenizer_config:
            if isinstance(tokenizer_config["auto_map"], (tuple, list)):
                # Legacy format for dynamic tokenizers
                tokenizer_auto_map = tokenizer_config["auto_map"]
            else:
                tokenizer_auto_map = tokenizer_config["auto_map"].get("AutoTokenizer", None)

        # If that did not work, let's try to use the config.
        if config_tokenizer_class is None:
            if not isinstance(config, PretrainedConfig):
                config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
                )
            config_tokenizer_class = config.tokenizer_class
            if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
                tokenizer_auto_map = config.auto_map["AutoTokenizer"]

        # If we have the tokenizer class from the tokenizer config or the model config we're good!
        if config_tokenizer_class is not None:
            tokenizer_class = None
            if tokenizer_auto_map is not None:
                if not trust_remote_code:
                    raise ValueError(
                        f"Loading {pretrained_model_name_or_path} requires you to execute the tokenizer file in that repo "
                        "on your local machine. Make sure you have read the code there to avoid malicious use, then set "
                        "the option `trust_remote_code=True` to remove this error."
                    )
                if kwargs.get("revision", None) is None:
                    logger.warning(
                        "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
                        "no malicious code has been contributed in a newer revision."
                    )

                if use_fast and tokenizer_auto_map[1] is not None:
                    class_ref = tokenizer_auto_map[1]
                else:
                    class_ref = tokenizer_auto_map[0]

                module_file, class_name = class_ref.split(".")
                tokenizer_class = get_class_from_dynamic_module(
                    pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
                )

            elif use_fast and not config_tokenizer_class.endswith("Fast"):
                tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
            if tokenizer_class is None:
                tokenizer_class_candidate = config_tokenizer_class
                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)

            if tokenizer_class is None:
                raise ValueError(
                    f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported."
                )
            return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

        # Otherwise we have to be creative.
        # if model is an encoder decoder, the encoder tokenizer class is used by default
        if isinstance(config, EncoderDecoderConfig):
            if type(config.decoder) is not type(config.encoder):  # noqa: E721
                logger.warning(
                    f"The encoder model config class: {config.encoder.__class__} is different from the decoder model "
                    f"config class: {config.decoder.__class__}. It is not recommended to use the "
                    "`AutoTokenizer.from_pretrained()` method in this case. Please use the encoder and decoder "
                    "specific tokenizer classes."
                )
            config = config.encoder

        model_type = config_class_to_model_type(type(config).__name__)
        if model_type is not None:
            tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)]
            if tokenizer_class_fast and (use_fast or tokenizer_class_py is None):
                return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
            else:
                if tokenizer_class_py is not None:
                    return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
                else:
                    raise ValueError(
                        "This tokenizer cannot be instantiated. Please make sure you have `sentencepiece` installed "
                        "in order to use this tokenizer."
                    )

        raise ValueError(
            f"Unrecognized configuration class {config.__class__} to build an AutoTokenizer.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in TOKENIZER_MAPPING.keys())}."
        )

    def register(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None):
        """
        Register a new tokenizer in this mapping.


        Args:
            config_class ([`PretrainedConfig`]):
                The configuration corresponding to the model to register.
            slow_tokenizer_class ([`PreTrainedTokenizer`], *optional*):
                The slow tokenizer to register.
            fast_tokenizer_class ([`PreTrainedTokenizerFast`], *optional*):
                The fast tokenizer to register.
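
        Example (a minimal sketch; `CustomConfig`, `CustomTokenizer` and `CustomTokenizerFast` are
        hypothetical stand-ins for your own classes):

        ```python
        AutoTokenizer.register(
            CustomConfig, slow_tokenizer_class=CustomTokenizer, fast_tokenizer_class=CustomTokenizerFast
        )
        ```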
        """
        if slow_tokenizer_class is None and fast_tokenizer_class is None:
            raise ValueError("You need to pass either a `slow_tokenizer_class` or a `fast_tokenizer_class")
        if slow_tokenizer_class is not None and issubclass(slow_tokenizer_class, PreTrainedTokenizerFast):
            raise ValueError("You passed a fast tokenizer in the `slow_tokenizer_class`.")
        if fast_tokenizer_class is not None and issubclass(fast_tokenizer_class, PreTrainedTokenizer):
            raise ValueError("You passed a slow tokenizer in the `fast_tokenizer_class`.")

        if (
            slow_tokenizer_class is not None
            and fast_tokenizer_class is not None
            and issubclass(fast_tokenizer_class, PreTrainedTokenizerFast)
            and fast_tokenizer_class.slow_tokenizer_class != slow_tokenizer_class
        ):
            raise ValueError(
                "The fast tokenizer class you are passing has a `slow_tokenizer_class` attribute that is not "
                "consistent with the slow tokenizer class you passed (fast tokenizer has "
                f"{fast_tokenizer_class.slow_tokenizer_class} and you passed {slow_tokenizer_class}. Fix one of those "
                "so they match!"
            )

        # Avoid resetting a set slow/fast tokenizer if we are passing just the other ones.
        if config_class in TOKENIZER_MAPPING._extra_content:
            existing_slow, existing_fast = TOKENIZER_MAPPING[config_class]
            if slow_tokenizer_class is None:
                slow_tokenizer_class = existing_slow
            if fast_tokenizer_class is None:
                fast_tokenizer_class = existing_fast

        TOKENIZER_MAPPING.register(config_class, (slow_tokenizer_class, fast_tokenizer_class))