# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Tokenizer class."""

import importlib
import json
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union

from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import get_file_from_repo, is_sentencepiece_available, is_tokenizers_available, logging
from ..encoder_decoder import EncoderDecoderConfig
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
    CONFIG_MAPPING_NAMES,
    AutoConfig,
    config_class_to_model_type,
    model_type_to_module_name,
    replace_list_option_in_docstrings,
)


logger = logging.get_logger(__name__)

if TYPE_CHECKING:
    # This significantly improves completion suggestion performance when
    # the transformers package is used with Microsoft's Pylance language server.
    TOKENIZER_MAPPING_NAMES: OrderedDict[str, Tuple[Optional[str], Optional[str]]] = OrderedDict()
else:
    TOKENIZER_MAPPING_NAMES = OrderedDict(
        [
            (
                "albert",
                (
                    "AlbertTokenizer" if is_sentencepiece_available() else None,
                    "AlbertTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("bart", ("BartTokenizer", "BartTokenizerFast")),
            (
                "barthez",
                (
                    "BarthezTokenizer" if is_sentencepiece_available() else None,
                    "BarthezTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("bartpho", ("BartphoTokenizer", None)),
            ("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
            ("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)),
            ("bert-japanese", ("BertJapaneseTokenizer", None)),
            ("bertweet", ("BertweetTokenizer", None)),
            (
                "big_bird",
                (
                    "BigBirdTokenizer" if is_sentencepiece_available() else None,
                    "BigBirdTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
            ("blenderbot", ("BlenderbotTokenizer", "BlenderbotTokenizerFast")),
            ("blenderbot-small", ("BlenderbotSmallTokenizer", None)),
            ("bloom", (None, "BloomTokenizerFast" if is_tokenizers_available() else None)),
            ("byt5", ("ByT5Tokenizer", None)),
            (
                "camembert",
                (
                    "CamembertTokenizer" if is_sentencepiece_available() else None,
                    "CamembertTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("canine", ("CanineTokenizer", None)),
            (
                "clip",
                (
                    "CLIPTokenizer",
                    "CLIPTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("codegen", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)),
            ("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
            (
                "cpm",
                (
                    "CpmTokenizer" if is_sentencepiece_available() else None,
                    "CpmTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("ctrl", ("CTRLTokenizer", None)),
            ("data2vec-text", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
            ("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
            (
                "deberta-v2",
                (
                    "DebertaV2Tokenizer" if is_sentencepiece_available() else None,
                    "DebertaV2TokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
            (
                "dpr",
                (
                    "DPRQuestionEncoderTokenizer",
                    "DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
            ("flaubert", ("FlaubertTokenizer", None)),
            ("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
            ("fsmt", ("FSMTTokenizer", None)),
            ("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
            ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
            ("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
            ("gpt_neox", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
            ("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
            ("groupvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
            ("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
            ("hubert", ("Wav2Vec2CTCTokenizer", None)),
            ("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
            ("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
            ("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)),
            ("layoutlmv3", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
            ("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
            ("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
            ("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
            (
                "longt5",
                (
                    "T5Tokenizer" if is_sentencepiece_available() else None,
                    "T5TokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("luke", ("LukeTokenizer", None)),
            ("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
            ("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
            ("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
            (
                "mbart",
                (
                    "MBartTokenizer" if is_sentencepiece_available() else None,
                    "MBartTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            (
                "mbart50",
                (
                    "MBart50Tokenizer" if is_sentencepiece_available() else None,
                    "MBart50TokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
            ("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
            ("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
            ("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
            (
                "mt5",
                (
                    "MT5Tokenizer" if is_sentencepiece_available() else None,
                    "MT5TokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("nezha", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
            (
                "nystromformer",
                (
                    "AlbertTokenizer" if is_sentencepiece_available() else None,
                    "AlbertTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)),
            ("opt", ("GPT2Tokenizer", None)),
            (
                "pegasus",
                (
                    "PegasusTokenizer" if is_sentencepiece_available() else None,
                    "PegasusTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            (
                "perceiver",
                (
                    "PerceiverTokenizer",
                    None,
                ),
            ),
            ("phobert", ("PhobertTokenizer", None)),
            ("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
            ("prophetnet", ("ProphetNetTokenizer", None)),
            ("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
            ("rag", ("RagTokenizer", None)),
            ("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
            (
                "reformer",
                (
                    "ReformerTokenizer" if is_sentencepiece_available() else None,
                    "ReformerTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            (
                "rembert",
                (
                    "RemBertTokenizer" if is_sentencepiece_available() else None,
                    "RemBertTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
            ("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
            ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
            ("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
            ("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
            ("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")),
            (
                "squeezebert",
                ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None),
            ),
            (
                "t5",
                (
                    "T5Tokenizer" if is_sentencepiece_available() else None,
                    "T5TokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("tapas", ("TapasTokenizer", None)),
            ("tapex", ("TapexTokenizer", None)),
            ("transfo-xl", ("TransfoXLTokenizer", None)),
            ("vilt", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
            ("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
            ("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
            ("wav2vec2-conformer", ("Wav2Vec2CTCTokenizer", None)),
            ("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)),
            (
                "xglm",
                (
                    "XGLMTokenizer" if is_sentencepiece_available() else None,
                    "XGLMTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("xlm", ("XLMTokenizer", None)),
            ("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)),
            (
                "xlm-roberta",
                (
                    "XLMRobertaTokenizer" if is_sentencepiece_available() else None,
                    "XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("xlm-roberta-xl", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
            (
                "xlnet",
                (
                    "XLNetTokenizer" if is_sentencepiece_available() else None,
                    "XLNetTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            (
                "yoso",
                (
                    "AlbertTokenizer" if is_sentencepiece_available() else None,
                    "AlbertTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
        ]
    )

TOKENIZER_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)

CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()}
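
# Illustrative lookup (a sketch, not part of the public API): the lazy mapping
# resolves a config class to its (slow, fast) tokenizer classes on first access:
#
#     from transformers import BertConfig
#     slow_cls, fast_cls = TOKENIZER_MAPPING[BertConfig]  # (BertTokenizer, BertTokenizerFast)
#
# Either entry is None when its optional dependency (sentencepiece/tokenizers) is missing.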


def tokenizer_class_from_name(class_name: str):
    if class_name == "PreTrainedTokenizerFast":
        return PreTrainedTokenizerFast

    for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items():
        if class_name in tokenizers:
            module_name = model_type_to_module_name(module_name)

            module = importlib.import_module(f".{module_name}", "transformers.models")
            try:
                return getattr(module, class_name)
            except AttributeError:
                continue

    for config, tokenizers in TOKENIZER_MAPPING._extra_content.items():
        for tokenizer in tokenizers:
            if getattr(tokenizer, "__name__", None) == class_name:
                return tokenizer

    # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
    # init and we return the proper dummy to get an appropriate error message.
    main_module = importlib.import_module("transformers")
    if hasattr(main_module, class_name):
        return getattr(main_module, class_name)

    return None
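
# Illustrative (a sketch, not part of the public API):
#
#     cls = tokenizer_class_from_name("BertTokenizerFast")
#
# `cls` is the real `BertTokenizerFast` class when the `tokenizers` dependency is
# installed; otherwise the dummy placeholder exposed by the main `transformers`
# init is returned, so that using it raises an informative error.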


def get_tokenizer_config(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    **kwargs,
):
    """
    Loads the tokenizer configuration from a pretrained model tokenizer configuration.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
              under a user or organization name, like `dbmdz/bert-base-german-cased`.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the configuration files and override the cached versions if
            they exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received files. Attempts to resume the download if such a file
            exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        use_auth_token (`str` or `bool`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `transformers-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.

    <Tip>

    Passing `use_auth_token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `Dict`: The configuration of the tokenizer.

    Examples:

    ```python
    # Download configuration from huggingface.co and cache.
    tokenizer_config = get_tokenizer_config("bert-base-uncased")
    # This model does not have a tokenizer config so the result will be an empty dict.
    tokenizer_config = get_tokenizer_config("xlm-roberta-base")

    # Save a pretrained tokenizer locally and you can reload its config
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    tokenizer.save_pretrained("tokenizer-test")
    tokenizer_config = get_tokenizer_config("tokenizer-test")
    ```"""
    resolved_config_file = get_file_from_repo(
        pretrained_model_name_or_path,
        TOKENIZER_CONFIG_FILE,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        use_auth_token=use_auth_token,
        revision=revision,
        local_files_only=local_files_only,
    )
    if resolved_config_file is None:
        logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
        return {}

    with open(resolved_config_file, encoding="utf-8") as reader:
        return json.load(reader)


class AutoTokenizer:
    r"""
    This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when
    created with the [`AutoTokenizer.from_pretrained`] class method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoTokenizer is designed to be instantiated "
            "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    @replace_list_option_in_docstrings(TOKENIZER_MAPPING_NAMES)
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        r"""
        Instantiate one of the tokenizer classes of the library from a pretrained model vocabulary.

        The tokenizer class to instantiate is selected based on the `model_type` property of the config object (either
        passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
        falling back to using pattern matching on `pretrained_model_name_or_path`:

        List options

        Params:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
                      user or organization name, like `dbmdz/bert-base-german-cased`.
                    - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
                      using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
                    - A path or url to a single saved vocabulary file if and only if the tokenizer only requires a
                      single vocabulary file (like Bert or XLNet), e.g.: `./my_model_directory/vocab.txt`. (Not
                      applicable to all derived classes)
            inputs (additional positional arguments, *optional*):
                Will be passed along to the Tokenizer `__init__()` method.
            config ([`PretrainedConfig`], *optional*):
                The configuration object used to determine the tokenizer class to instantiate.
            cache_dir (`str` or `os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files and override
                the cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            subfolder (`str`, *optional*):
                In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for
                facebook/rag-token-base), specify it here.
            use_fast (`bool`, *optional*, defaults to `True`):
                Whether or not to try to load the fast version of the tokenizer.
            tokenizer_type (`str`, *optional*):
                Tokenizer type to be loaded.
            trust_remote_code (`bool`, *optional*, defaults to `False`):
                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
                should only be set to `True` for repositories you trust and in which you have read the code, as it will
                execute code present on the Hub on your local machine.
            kwargs (additional keyword arguments, *optional*):
                Will be passed to the Tokenizer `__init__()` method. Can be used to set special tokens like
                `bos_token`, `eos_token`, `unk_token`, `sep_token`, `pad_token`, `cls_token`, `mask_token`,
                `additional_special_tokens`. See parameters in the `__init__()` for more details.

        Examples:

        ```python
        >>> from transformers import AutoTokenizer

        >>> # Download vocabulary from huggingface.co and cache.
        >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

        >>> # Download vocabulary from huggingface.co (user-uploaded) and cache.
        >>> tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-german-cased")

        >>> # If vocabulary files are in a directory (e.g. tokenizer was saved using *save_pretrained('./test/saved_model/')*)
        >>> tokenizer = AutoTokenizer.from_pretrained("./test/bert_saved_model/")

        >>> # Download vocabulary from huggingface.co and define model-specific arguments
        >>> tokenizer = AutoTokenizer.from_pretrained("roberta-base", add_prefix_space=True)
        ```"""
        config = kwargs.pop("config", None)
        kwargs["_from_auto"] = True

        use_fast = kwargs.pop("use_fast", True)
        tokenizer_type = kwargs.pop("tokenizer_type", None)
        trust_remote_code = kwargs.pop("trust_remote_code", False)

        # First, let's see whether the tokenizer_type is passed so that we can leverage it
        if tokenizer_type is not None:
            tokenizer_class = None
            tokenizer_class_tuple = TOKENIZER_MAPPING_NAMES.get(tokenizer_type, None)

            if tokenizer_class_tuple is None:
                raise ValueError(
                    f"Passed `tokenizer_type` {tokenizer_type} does not exist. `tokenizer_type` should be one of "
                    f"{', '.join(c for c in TOKENIZER_MAPPING_NAMES.keys())}."
                )

            tokenizer_class_name, tokenizer_fast_class_name = tokenizer_class_tuple

            if use_fast and tokenizer_fast_class_name is not None:
                tokenizer_class = tokenizer_class_from_name(tokenizer_fast_class_name)

            if tokenizer_class is None:
                tokenizer_class = tokenizer_class_from_name(tokenizer_class_name)

            if tokenizer_class is None:
                raise ValueError(f"Tokenizer class {tokenizer_class_name} is not currently imported.")

            return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
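
        # Illustrative: passing `tokenizer_type` takes the short-circuit branch above
        # and forces a tokenizer family from TOKENIZER_MAPPING_NAMES, e.g.
        #
        #     AutoTokenizer.from_pretrained("bert-base-uncased", tokenizer_type="bert")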

        # Next, let's try to use the tokenizer_config file to get the tokenizer class.
        tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
        config_tokenizer_class = tokenizer_config.get("tokenizer_class")
        tokenizer_auto_map = None
        if "auto_map" in tokenizer_config:
            if isinstance(tokenizer_config["auto_map"], (tuple, list)):
                # Legacy format for dynamic tokenizers
                tokenizer_auto_map = tokenizer_config["auto_map"]
            else:
                tokenizer_auto_map = tokenizer_config["auto_map"].get("AutoTokenizer", None)

        # If that did not work, let's try to use the config.
        if config_tokenizer_class is None:
            if not isinstance(config, PretrainedConfig):
                config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
                )
            config_tokenizer_class = config.tokenizer_class
            if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
                tokenizer_auto_map = config.auto_map["AutoTokenizer"]

        # If we have the tokenizer class from the tokenizer config or the model config we're good!
        if config_tokenizer_class is not None:
            tokenizer_class = None
            if tokenizer_auto_map is not None:
                if not trust_remote_code:
                    raise ValueError(
                        f"Loading {pretrained_model_name_or_path} requires you to execute the tokenizer file in that"
                        " repo on your local machine. Make sure you have read the code there to avoid malicious use,"
                        " then set the option `trust_remote_code=True` to remove this error."
                    )
                if kwargs.get("revision", None) is None:
                    logger.warning(
                        "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure"
                        " no malicious code has been contributed in a newer revision."
                    )

                if use_fast and tokenizer_auto_map[1] is not None:
                    class_ref = tokenizer_auto_map[1]
                else:
                    class_ref = tokenizer_auto_map[0]

                module_file, class_name = class_ref.split(".")
                tokenizer_class = get_class_from_dynamic_module(
                    pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
                )

            elif use_fast and not config_tokenizer_class.endswith("Fast"):
                tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
            if tokenizer_class is None:
                tokenizer_class_candidate = config_tokenizer_class
                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)

            if tokenizer_class is None:
                raise ValueError(
                    f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported."
                )
            return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

        # Otherwise we have to be creative.
        # if model is an encoder decoder, the encoder tokenizer class is used by default
        if isinstance(config, EncoderDecoderConfig):
            if type(config.decoder) is not type(config.encoder):  # noqa: E721
                logger.warning(
                    f"The encoder model config class: {config.encoder.__class__} is different from the decoder model "
                    f"config class: {config.decoder.__class__}. It is not recommended to use the "
                    "`AutoTokenizer.from_pretrained()` method in this case. Please use the encoder and decoder "
                    "specific tokenizer classes."
                )
            config = config.encoder

        model_type = config_class_to_model_type(type(config).__name__)
        if model_type is not None:
            tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)]
            if tokenizer_class_fast and (use_fast or tokenizer_class_py is None):
                return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
            else:
                if tokenizer_class_py is not None:
                    return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
                else:
                    raise ValueError(
                        "This tokenizer cannot be instantiated. Please make sure you have `sentencepiece` installed "
                        "in order to use this tokenizer."
                    )

        raise ValueError(
            f"Unrecognized configuration class {config.__class__} to build an AutoTokenizer.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in TOKENIZER_MAPPING.keys())}."
        )
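
    # Illustrative usage of the custom-code path above (a sketch; the repo name is
    # hypothetical and must ship its own tokenizer module referenced by an `auto_map`
    # entry in its tokenizer_config.json). Pinning `revision` is encouraged, as the
    # warning above explains:
    #
    #     tokenizer = AutoTokenizer.from_pretrained(
    #         "username/custom-tokenizer", trust_remote_code=True, revision="main"
    #     )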

    def register(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None):
        """
        Register a new tokenizer in this mapping.

        Args:
            config_class ([`PretrainedConfig`]):
                The configuration corresponding to the model to register.
            slow_tokenizer_class ([`PreTrainedTokenizer`], *optional*):
                The slow tokenizer to register.
            fast_tokenizer_class ([`PreTrainedTokenizerFast`], *optional*):
                The fast tokenizer to register.
        """
        if slow_tokenizer_class is None and fast_tokenizer_class is None:
            raise ValueError("You need to pass either a `slow_tokenizer_class` or a `fast_tokenizer_class`.")
        if slow_tokenizer_class is not None and issubclass(slow_tokenizer_class, PreTrainedTokenizerFast):
            raise ValueError("You passed a fast tokenizer in the `slow_tokenizer_class`.")
        if fast_tokenizer_class is not None and issubclass(fast_tokenizer_class, PreTrainedTokenizer):
            raise ValueError("You passed a slow tokenizer in the `fast_tokenizer_class`.")

        if (
            slow_tokenizer_class is not None
            and fast_tokenizer_class is not None
            and issubclass(fast_tokenizer_class, PreTrainedTokenizerFast)
            and fast_tokenizer_class.slow_tokenizer_class != slow_tokenizer_class
        ):
            raise ValueError(
                "The fast tokenizer class you are passing has a `slow_tokenizer_class` attribute that is not "
                "consistent with the slow tokenizer class you passed (fast tokenizer has "
                f"{fast_tokenizer_class.slow_tokenizer_class} and you passed {slow_tokenizer_class}). Fix one of those "
                "so they match!"
            )

        # Avoid resetting a set slow/fast tokenizer if we are passing just the other ones.
        if config_class in TOKENIZER_MAPPING._extra_content:
            existing_slow, existing_fast = TOKENIZER_MAPPING[config_class]
            if slow_tokenizer_class is None:
                slow_tokenizer_class = existing_slow
            if fast_tokenizer_class is None:
                fast_tokenizer_class = existing_fast

        TOKENIZER_MAPPING.register(config_class, (slow_tokenizer_class, fast_tokenizer_class))
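

# Illustrative registration of a custom tokenizer pair (a sketch; `MyConfig`,
# `MyTokenizer` and `MyTokenizerFast` are hypothetical user-defined classes, where
# `MyTokenizerFast.slow_tokenizer_class` must be `MyTokenizer`):
#
#     AutoTokenizer.register(MyConfig, slow_tokenizer_class=MyTokenizer, fast_tokenizer_class=MyTokenizerFast)
#     TOKENIZER_MAPPING[MyConfig]  # -> (MyTokenizer, MyTokenizerFast)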