# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Tokenizer class."""

import importlib
import json
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union

from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import cached_file, extract_commit_hash, is_sentencepiece_available, is_tokenizers_available, logging
from ..encoder_decoder import EncoderDecoderConfig
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
    CONFIG_MAPPING_NAMES,
    AutoConfig,
    config_class_to_model_type,
    model_type_to_module_name,
    replace_list_option_in_docstrings,
)


logger = logging.get_logger(__name__)

if TYPE_CHECKING:
    # This significantly improves completion suggestion performance when
    # the transformers package is used with Microsoft's Pylance language server.
    TOKENIZER_MAPPING_NAMES: OrderedDict[str, Tuple[Optional[str], Optional[str]]] = OrderedDict()
else:
    TOKENIZER_MAPPING_NAMES = OrderedDict(
        [
            (
                "albert",
                (
                    "AlbertTokenizer" if is_sentencepiece_available() else None,
                    "AlbertTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("bart", ("BartTokenizer", "BartTokenizerFast")),
            (
                "barthez",
                (
                    "BarthezTokenizer" if is_sentencepiece_available() else None,
                    "BarthezTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("bartpho", ("BartphoTokenizer", None)),
            ("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
            ("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)),
            ("bert-japanese", ("BertJapaneseTokenizer", None)),
            ("bertweet", ("BertweetTokenizer", None)),
            (
                "big_bird",
                (
                    "BigBirdTokenizer" if is_sentencepiece_available() else None,
                    "BigBirdTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
            ("blenderbot", ("BlenderbotTokenizer", "BlenderbotTokenizerFast")),
            ("blenderbot-small", ("BlenderbotSmallTokenizer", None)),
            ("bloom", (None, "BloomTokenizerFast" if is_tokenizers_available() else None)),
            ("byt5", ("ByT5Tokenizer", None)),
            (
                "camembert",
                (
                    "CamembertTokenizer" if is_sentencepiece_available() else None,
                    "CamembertTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("canine", ("CanineTokenizer", None)),
            (
                "clip",
                (
                    "CLIPTokenizer",
                    "CLIPTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("codegen", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)),
            ("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
            (
                "cpm",
                (
                    "CpmTokenizer" if is_sentencepiece_available() else None,
                    "CpmTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("ctrl", ("CTRLTokenizer", None)),
            ("data2vec-text", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
            ("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
            (
                "deberta-v2",
                (
                    "DebertaV2Tokenizer" if is_sentencepiece_available() else None,
                    "DebertaV2TokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
            (
                "dpr",
                (
                    "DPRQuestionEncoderTokenizer",
                    "DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
            ("ernie", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
            ("flaubert", ("FlaubertTokenizer", None)),
            ("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
            ("fsmt", ("FSMTTokenizer", None)),
            ("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
            ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
            ("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
            ("gpt_neox", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
            ("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
            ("groupvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
            ("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
            ("hubert", ("Wav2Vec2CTCTokenizer", None)),
            ("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
            ("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
            ("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)),
            ("layoutlmv3", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
            ("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
            ("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
            ("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
            (
                "longt5",
                (
                    "T5Tokenizer" if is_sentencepiece_available() else None,
                    "T5TokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("luke", ("LukeTokenizer", None)),
            ("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
            ("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
            ("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
            (
                "mbart",
                (
                    "MBartTokenizer" if is_sentencepiece_available() else None,
                    "MBartTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            (
                "mbart50",
                (
                    "MBart50Tokenizer" if is_sentencepiece_available() else None,
                    "MBart50TokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
            ("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
            ("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
            ("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
            (
                "mt5",
                (
                    "MT5Tokenizer" if is_sentencepiece_available() else None,
                    "MT5TokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("mvp", ("MvpTokenizer", "MvpTokenizerFast" if is_tokenizers_available() else None)),
            ("nezha", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
            (
                "nllb",
                (
                    "NllbTokenizer" if is_sentencepiece_available() else None,
                    "NllbTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            (
                "nystromformer",
                (
                    "AlbertTokenizer" if is_sentencepiece_available() else None,
                    "AlbertTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)),
            ("opt", ("GPT2Tokenizer", None)),
            ("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
            (
                "pegasus",
                (
                    "PegasusTokenizer" if is_sentencepiece_available() else None,
                    "PegasusTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            (
                "perceiver",
                (
                    "PerceiverTokenizer",
                    None,
                ),
            ),
            ("phobert", ("PhobertTokenizer", None)),
            ("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
            ("prophetnet", ("ProphetNetTokenizer", None)),
            ("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
            ("rag", ("RagTokenizer", None)),
            ("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
            (
                "reformer",
                (
                    "ReformerTokenizer" if is_sentencepiece_available() else None,
                    "ReformerTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            (
                "rembert",
                (
                    "RemBertTokenizer" if is_sentencepiece_available() else None,
                    "RemBertTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
            ("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
            ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
            ("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
            ("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
            ("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")),
            (
                "squeezebert",
                ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None),
            ),
            (
                "t5",
                (
                    "T5Tokenizer" if is_sentencepiece_available() else None,
                    "T5TokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("tapas", ("TapasTokenizer", None)),
            ("tapex", ("TapexTokenizer", None)),
            ("transfo-xl", ("TransfoXLTokenizer", None)),
            ("vilt", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
            ("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
            ("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
            ("wav2vec2-conformer", ("Wav2Vec2CTCTokenizer", None)),
            ("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)),
            ("xclip", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
            (
                "xglm",
                (
                    "XGLMTokenizer" if is_sentencepiece_available() else None,
                    "XGLMTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("xlm", ("XLMTokenizer", None)),
            ("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)),
            (
                "xlm-roberta",
                (
                    "XLMRobertaTokenizer" if is_sentencepiece_available() else None,
                    "XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            ("xlm-roberta-xl", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
            (
                "xlnet",
                (
                    "XLNetTokenizer" if is_sentencepiece_available() else None,
                    "XLNetTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
            (
                "yoso",
                (
                    "AlbertTokenizer" if is_sentencepiece_available() else None,
                    "AlbertTokenizerFast" if is_tokenizers_available() else None,
                ),
            ),
        ]
    )

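# Mapping from config classes to (slow tokenizer, fast tokenizer) pairs. `_LazyAutoMapping` resolves entries on
# demand, so importing this module does not import every tokenizer class listed above.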
TOKENIZER_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)

CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()}


def tokenizer_class_from_name(class_name: str):
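    """
    Resolve a tokenizer class from its name. Looks through the tokenizer classes declared in
    `TOKENIZER_MAPPING_NAMES` (importing the relevant model module on demand), then through tokenizers registered at
    runtime, and finally falls back to the main `transformers` module so that a missing-dependency dummy class can
    raise an informative error.
    """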
    if class_name == "PreTrainedTokenizerFast":
        return PreTrainedTokenizerFast

    for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items():
        if class_name in tokenizers:
            module_name = model_type_to_module_name(module_name)

            module = importlib.import_module(f".{module_name}", "transformers.models")
            try:
                return getattr(module, class_name)
            except AttributeError:
                continue

    for config, tokenizers in TOKENIZER_MAPPING._extra_content.items():
        for tokenizer in tokenizers:
            if getattr(tokenizer, "__name__", None) == class_name:
                return tokenizer

    # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
    # init and we return the proper dummy to get an appropriate error message.
    main_module = importlib.import_module("transformers")
    if hasattr(main_module, class_name):
        return getattr(main_module, class_name)

    return None


def get_tokenizer_config(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    **kwargs,
):
    """
    Loads the tokenizer configuration from a pretrained model's tokenizer configuration file.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
              under a user or organization name, like `dbmdz/bert-base-german-cased`.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the configuration files and override the cached versions if
            they exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received files. Attempts to resume the download if such a file
            exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        use_auth_token (`str` or `bool`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.

    <Tip>

    Passing `use_auth_token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `Dict`: The configuration of the tokenizer.

    Examples:

    ```python
    # Download configuration from huggingface.co and cache.
    tokenizer_config = get_tokenizer_config("bert-base-uncased")
    # This model does not have a tokenizer config so the result will be an empty dict.
    tokenizer_config = get_tokenizer_config("xlm-roberta-base")

    # Save a pretrained tokenizer locally and you can reload its config
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    tokenizer.save_pretrained("tokenizer-test")
    tokenizer_config = get_tokenizer_config("tokenizer-test")
    ```"""
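    # Reuse a commit hash passed down by the caller (if any) so that every file is fetched from the same revision of
    # the repo.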
    commit_hash = kwargs.get("_commit_hash", None)
    resolved_config_file = cached_file(
        pretrained_model_name_or_path,
        TOKENIZER_CONFIG_FILE,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        use_auth_token=use_auth_token,
        revision=revision,
        local_files_only=local_files_only,
        _raise_exceptions_for_missing_entries=False,
        _raise_exceptions_for_connection_errors=False,
        _commit_hash=commit_hash,
    )
    if resolved_config_file is None:
        logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
        return {}
    commit_hash = extract_commit_hash(resolved_config_file, commit_hash)

    with open(resolved_config_file, encoding="utf-8") as reader:
        result = json.load(reader)
    result["_commit_hash"] = commit_hash
    return result


class AutoTokenizer:
    r"""
    This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when
    created with the [`AutoTokenizer.from_pretrained`] class method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoTokenizer is designed to be instantiated "
            "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    @replace_list_option_in_docstrings(TOKENIZER_MAPPING_NAMES)
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        r"""
        Instantiate one of the tokenizer classes of the library from a pretrained model vocabulary.

        The tokenizer class to instantiate is selected based on the `model_type` property of the config object (either
        passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
        falling back to using pattern matching on `pretrained_model_name_or_path`:

        List options

        Params:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
                      user or organization name, like `dbmdz/bert-base-german-cased`.
                    - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
                      using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
                    - A path or url to a single saved vocabulary file if and only if the tokenizer only requires a
                      single vocabulary file (like Bert or XLNet), e.g.: `./my_model_directory/vocab.txt`. (Not
                      applicable to all derived classes)
            inputs (additional positional arguments, *optional*):
                Will be passed along to the Tokenizer `__init__()` method.
            config ([`PretrainedConfig`], *optional*):
                The configuration object used to determine the tokenizer class to instantiate.
            cache_dir (`str` or `os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files and override
                the cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            subfolder (`str`, *optional*):
                In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for
                facebook/rag-token-base), specify it here.
            use_fast (`bool`, *optional*, defaults to `True`):
                Whether or not to try to load the fast version of the tokenizer.
            tokenizer_type (`str`, *optional*):
                Tokenizer type to be loaded.
            trust_remote_code (`bool`, *optional*, defaults to `False`):
                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
                should only be set to `True` for repositories you trust and in which you have read the code, as it will
                execute code present on the Hub on your local machine.
            kwargs (additional keyword arguments, *optional*):
                Will be passed to the Tokenizer `__init__()` method. Can be used to set special tokens like
                `bos_token`, `eos_token`, `unk_token`, `sep_token`, `pad_token`, `cls_token`, `mask_token`,
                `additional_special_tokens`. See parameters in the `__init__()` for more details.

        Examples:

        ```python
        >>> from transformers import AutoTokenizer

        >>> # Download vocabulary from huggingface.co and cache.
        >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

        >>> # Download vocabulary from huggingface.co (user-uploaded) and cache.
        >>> tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-german-cased")

        >>> # If vocabulary files are in a directory (e.g. tokenizer was saved using *save_pretrained('./test/saved_model/')*)
        >>> tokenizer = AutoTokenizer.from_pretrained("./test/saved_model/")

        >>> # Download vocabulary from huggingface.co and define model-specific arguments
        >>> tokenizer = AutoTokenizer.from_pretrained("roberta-base", add_prefix_space=True)
        ```"""
        config = kwargs.pop("config", None)
        kwargs["_from_auto"] = True

        use_fast = kwargs.pop("use_fast", True)
        tokenizer_type = kwargs.pop("tokenizer_type", None)
        trust_remote_code = kwargs.pop("trust_remote_code", False)

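        # Resolution order for the tokenizer class:
        #   1. the explicit `tokenizer_type` argument,
        #   2. the `tokenizer_class` (and `auto_map`) recorded in tokenizer_config.json,
        #   3. the `tokenizer_class` (and `auto_map`) of the model config,
        #   4. the model type of the config, matched against TOKENIZER_MAPPING.
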
        # First, let's see whether the tokenizer_type is passed so that we can leverage it
        if tokenizer_type is not None:
            tokenizer_class = None
            tokenizer_class_tuple = TOKENIZER_MAPPING_NAMES.get(tokenizer_type, None)

            if tokenizer_class_tuple is None:
                raise ValueError(
                    f"Passed `tokenizer_type` {tokenizer_type} does not exist. `tokenizer_type` should be one of "
                    f"{', '.join(c for c in TOKENIZER_MAPPING_NAMES.keys())}."
                )

            tokenizer_class_name, tokenizer_fast_class_name = tokenizer_class_tuple

            if use_fast and tokenizer_fast_class_name is not None:
                tokenizer_class = tokenizer_class_from_name(tokenizer_fast_class_name)

            if tokenizer_class is None:
                tokenizer_class = tokenizer_class_from_name(tokenizer_class_name)

            if tokenizer_class is None:
                raise ValueError(f"Tokenizer class {tokenizer_class_name} is not currently imported.")

            return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

        # Next, let's try to use the tokenizer_config file to get the tokenizer class.
        tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
        if "_commit_hash" in tokenizer_config:
            kwargs["_commit_hash"] = tokenizer_config["_commit_hash"]
        config_tokenizer_class = tokenizer_config.get("tokenizer_class")
        tokenizer_auto_map = None
        if "auto_map" in tokenizer_config:
            if isinstance(tokenizer_config["auto_map"], (tuple, list)):
                # Legacy format for dynamic tokenizers
                tokenizer_auto_map = tokenizer_config["auto_map"]
            else:
                tokenizer_auto_map = tokenizer_config["auto_map"].get("AutoTokenizer", None)

        # If that did not work, let's try to use the config.
        if config_tokenizer_class is None:
            if not isinstance(config, PretrainedConfig):
                config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
                )
            config_tokenizer_class = config.tokenizer_class
            if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
                tokenizer_auto_map = config.auto_map["AutoTokenizer"]

        # If we have the tokenizer class from the tokenizer config or the model config we're good!
        if config_tokenizer_class is not None:
            tokenizer_class = None
            if tokenizer_auto_map is not None:
                if not trust_remote_code:
                    raise ValueError(
                        f"Loading {pretrained_model_name_or_path} requires you to execute the tokenizer file in that"
                        " repo on your local machine. Make sure you have read the code there to avoid malicious use,"
                        " then set the option `trust_remote_code=True` to remove this error."
                    )
                if kwargs.get("revision", None) is None:
                    logger.warning(
                        "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure"
                        " no malicious code has been contributed in a newer revision."
                    )

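                # `tokenizer_auto_map` is a (slow_class_ref, fast_class_ref) pair of "module.ClassName" references
                # pointing at code inside the downloaded repo; the fast one is preferred when `use_fast=True`.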
                if use_fast and tokenizer_auto_map[1] is not None:
                    class_ref = tokenizer_auto_map[1]
                else:
                    class_ref = tokenizer_auto_map[0]

                module_file, class_name = class_ref.split(".")
                tokenizer_class = get_class_from_dynamic_module(
                    pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
                )

            elif use_fast and not config_tokenizer_class.endswith("Fast"):
                tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
            if tokenizer_class is None:
                tokenizer_class_candidate = config_tokenizer_class
                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)

            if tokenizer_class is None:
                raise ValueError(
                    f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported."
                )
            return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

        # Otherwise we have to be creative.
        # if model is an encoder decoder, the encoder tokenizer class is used by default
        if isinstance(config, EncoderDecoderConfig):
            if type(config.decoder) is not type(config.encoder):  # noqa: E721
                logger.warning(
                    f"The encoder model config class: {config.encoder.__class__} is different from the decoder model "
                    f"config class: {config.decoder.__class__}. It is not recommended to use the "
                    "`AutoTokenizer.from_pretrained()` method in this case. Please use the encoder and decoder "
                    "specific tokenizer classes."
                )
            config = config.encoder

        model_type = config_class_to_model_type(type(config).__name__)
        if model_type is not None:
            tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)]
            if tokenizer_class_fast and (use_fast or tokenizer_class_py is None):
                return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
            else:
                if tokenizer_class_py is not None:
                    return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
                else:
                    raise ValueError(
                        "This tokenizer cannot be instantiated. Please make sure you have `sentencepiece` installed "
                        "in order to use this tokenizer."
                    )

        raise ValueError(
            f"Unrecognized configuration class {config.__class__} to build an AutoTokenizer.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in TOKENIZER_MAPPING.keys())}."
        )

    def register(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None):
        """
        Register a new tokenizer in this mapping.


        Args:
            config_class ([`PretrainedConfig`]):
                The configuration corresponding to the model to register.
            slow_tokenizer_class ([`PreTrainedTokenizer`], *optional*):
                The slow tokenizer to register.
            fast_tokenizer_class ([`PreTrainedTokenizerFast`], *optional*):
                The fast tokenizer to register.
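
        Example (a minimal sketch; `CustomConfig` and `CustomTokenizer` are hypothetical user-defined classes):

        ```python
        from transformers import AutoTokenizer

        AutoTokenizer.register(CustomConfig, slow_tokenizer_class=CustomTokenizer)

        # From now on, checkpoints whose config is a `CustomConfig` resolve to `CustomTokenizer`.
        tokenizer = AutoTokenizer.from_pretrained("path/to/custom-checkpoint")
        ```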
        """
        if slow_tokenizer_class is None and fast_tokenizer_class is None:
            raise ValueError("You need to pass either a `slow_tokenizer_class` or a `fast_tokenizer_class`.")
        if slow_tokenizer_class is not None and issubclass(slow_tokenizer_class, PreTrainedTokenizerFast):
            raise ValueError("You passed a fast tokenizer in the `slow_tokenizer_class`.")
        if fast_tokenizer_class is not None and issubclass(fast_tokenizer_class, PreTrainedTokenizer):
            raise ValueError("You passed a slow tokenizer in the `fast_tokenizer_class`.")

        if (
            slow_tokenizer_class is not None
            and fast_tokenizer_class is not None
            and issubclass(fast_tokenizer_class, PreTrainedTokenizerFast)
            and fast_tokenizer_class.slow_tokenizer_class != slow_tokenizer_class
        ):
            raise ValueError(
                "The fast tokenizer class you are passing has a `slow_tokenizer_class` attribute that is not "
                "consistent with the slow tokenizer class you passed (fast tokenizer has "
                f"{fast_tokenizer_class.slow_tokenizer_class} and you passed {slow_tokenizer_class}). Fix one of those "
                "so they match!"
            )

        # Avoid resetting a set slow/fast tokenizer if we are passing just the other ones.
        if config_class in TOKENIZER_MAPPING._extra_content:
            existing_slow, existing_fast = TOKENIZER_MAPPING[config_class]
            if slow_tokenizer_class is None:
                slow_tokenizer_class = existing_slow
            if fast_tokenizer_class is None:
                fast_tokenizer_class = existing_fast

        TOKENIZER_MAPPING.register(config_class, (slow_tokenizer_class, fast_tokenizer_class))