# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Config class. """
import importlib
import re
import warnings
from collections import OrderedDict
from typing import List, Union

from ...configuration_utils import PretrainedConfig
from ...file_utils import CONFIG_NAME


CONFIG_MAPPING_NAMES = OrderedDict(
    [
        # Add configs here
        ("gptj", "GPTJConfig"),
        ("layoutlmv2", "LayoutLMv2Config"),
        ("beit", "BeitConfig"),
        ("rembert", "RemBertConfig"),
        ("visual_bert", "VisualBertConfig"),
        ("canine", "CanineConfig"),
        ("roformer", "RoFormerConfig"),
        ("clip", "CLIPConfig"),
        ("bigbird_pegasus", "BigBirdPegasusConfig"),
        ("deit", "DeiTConfig"),
        ("luke", "LukeConfig"),
        ("detr", "DetrConfig"),
        ("gpt_neo", "GPTNeoConfig"),
        ("big_bird", "BigBirdConfig"),
        ("speech_to_text_2", "Speech2Text2Config"),
        ("speech_to_text", "Speech2TextConfig"),
        ("vit", "ViTConfig"),
        ("wav2vec2", "Wav2Vec2Config"),
        ("m2m_100", "M2M100Config"),
        ("convbert", "ConvBertConfig"),
        ("led", "LEDConfig"),
        ("blenderbot-small", "BlenderbotSmallConfig"),
        ("retribert", "RetriBertConfig"),
        ("ibert", "IBertConfig"),
        ("mt5", "MT5Config"),
        ("t5", "T5Config"),
        ("mobilebert", "MobileBertConfig"),
        ("distilbert", "DistilBertConfig"),
        ("albert", "AlbertConfig"),
        ("bert-generation", "BertGenerationConfig"),
        ("camembert", "CamembertConfig"),
        ("xlm-roberta", "XLMRobertaConfig"),
        ("pegasus", "PegasusConfig"),
        ("marian", "MarianConfig"),
        ("mbart", "MBartConfig"),
        ("megatron-bert", "MegatronBertConfig"),
        ("mpnet", "MPNetConfig"),
        ("bart", "BartConfig"),
        ("blenderbot", "BlenderbotConfig"),
        ("reformer", "ReformerConfig"),
        ("longformer", "LongformerConfig"),
        ("roberta", "RobertaConfig"),
        ("deberta-v2", "DebertaV2Config"),
        ("deberta", "DebertaConfig"),
        ("flaubert", "FlaubertConfig"),
        ("fsmt", "FSMTConfig"),
        ("squeezebert", "SqueezeBertConfig"),
        ("hubert", "HubertConfig"),
        ("bert", "BertConfig"),
        ("openai-gpt", "OpenAIGPTConfig"),
        ("gpt2", "GPT2Config"),
        ("transfo-xl", "TransfoXLConfig"),
        ("xlnet", "XLNetConfig"),
        ("xlm-prophetnet", "XLMProphetNetConfig"),
        ("prophetnet", "ProphetNetConfig"),
        ("xlm", "XLMConfig"),
        ("ctrl", "CTRLConfig"),
        ("electra", "ElectraConfig"),
        ("speech-encoder-decoder", "SpeechEncoderDecoderConfig"),
        ("encoder-decoder", "EncoderDecoderConfig"),
        ("funnel", "FunnelConfig"),
        ("lxmert", "LxmertConfig"),
        ("dpr", "DPRConfig"),
        ("layoutlm", "LayoutLMConfig"),
        ("rag", "RagConfig"),
        ("tapas", "TapasConfig"),
        ("splinter", "SplinterConfig"),
    ]
)

CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
    [
        # Add archive maps here
        ("pegasus", "PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("canine", "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("clip", "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("bigbird_pegasus", "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("detr", "DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("gpt_neo", "GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("big_bird", "BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("megatron-bert", "MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("speech_to_text", "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("speech_to_text_2", "SPEECH_TO_TEXT_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("vit", "VIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("wav2vec2", "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("convbert", "CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("blenderbot-small", "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("bert", "BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("bart", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("blenderbot", "BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("mbart", "MBART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("openai-gpt", "OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("ctrl", "CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("xlnet", "XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("xlm", "XLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("roberta", "ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("distilbert", "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("albert", "ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("camembert", "CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("t5", "T5_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("xlm-roberta", "XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("flaubert", "FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("fsmt", "FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("electra", "ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("retribert", "RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("funnel", "FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("dpr", "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("deberta", "DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("deberta-v2", "DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("squeezebert", "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("xlm-prophetnet", "XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("prophetnet", "PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("mpnet", "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("hubert", "HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("splinter", "SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
    ]
)

MODEL_NAMES_MAPPING = OrderedDict(
    [
        # Add full (and cased) model names here
        ("gptj", "GPT-J"),
        ("beit", "BeiT"),
        ("rembert", "RemBERT"),
        ("layoutlmv2", "LayoutLMv2"),
        ("visual_bert", "VisualBert"),
        ("canine", "Canine"),
        ("roformer", "RoFormer"),
        ("clip", "CLIP"),
        ("bigbird_pegasus", "BigBirdPegasus"),
        ("deit", "DeiT"),
        ("luke", "LUKE"),
        ("detr", "DETR"),
        ("gpt_neo", "GPT Neo"),
        ("big_bird", "BigBird"),
        ("speech_to_text_2", "Speech2Text2"),
        ("speech_to_text", "Speech2Text"),
        ("vit", "ViT"),
        ("wav2vec2", "Wav2Vec2"),
        ("m2m_100", "M2M100"),
        ("convbert", "ConvBERT"),
        ("led", "LED"),
        ("blenderbot-small", "BlenderbotSmall"),
        ("retribert", "RetriBERT"),
        ("ibert", "I-BERT"),
        ("t5", "T5"),
        ("mobilebert", "MobileBERT"),
        ("distilbert", "DistilBERT"),
        ("albert", "ALBERT"),
        ("bert-generation", "Bert Generation"),
        ("camembert", "CamemBERT"),
        ("xlm-roberta", "XLM-RoBERTa"),
        ("pegasus", "Pegasus"),
        ("blenderbot", "Blenderbot"),
        ("marian", "Marian"),
        ("mbart", "mBART"),
        ("megatron-bert", "MegatronBert"),
        ("bart", "BART"),
        ("reformer", "Reformer"),
        ("longformer", "Longformer"),
        ("roberta", "RoBERTa"),
        ("flaubert", "FlauBERT"),
        ("fsmt", "FairSeq Machine-Translation"),
        ("squeezebert", "SqueezeBERT"),
        ("bert", "BERT"),
        ("openai-gpt", "OpenAI GPT"),
        ("gpt2", "OpenAI GPT-2"),
        ("transfo-xl", "Transformer-XL"),
        ("xlnet", "XLNet"),
        ("xlm", "XLM"),
        ("ctrl", "CTRL"),
        ("electra", "ELECTRA"),
        ("encoder-decoder", "Encoder decoder"),
        ("speech-encoder-decoder", "Speech Encoder decoder"),
        ("funnel", "Funnel Transformer"),
        ("lxmert", "LXMERT"),
        ("deberta-v2", "DeBERTa-v2"),
        ("deberta", "DeBERTa"),
        ("layoutlm", "LayoutLM"),
        ("dpr", "DPR"),
        ("rag", "RAG"),
        ("xlm-prophetnet", "XLMProphetNet"),
        ("prophetnet", "ProphetNet"),
        ("mt5", "mT5"),
        ("mpnet", "MPNet"),
        ("tapas", "TAPAS"),
        ("hubert", "Hubert"),
        ("barthez", "BARThez"),
        ("phobert", "PhoBERT"),
        ("cpm", "CPM"),
        ("bertweet", "Bertweet"),
        ("bert-japanese", "BertJapanese"),
        ("byt5", "ByT5"),
        ("mbart50", "mBART-50"),
        ("splinter", "Splinter"),
    ]
)

SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict([("openai-gpt", "openai")])


def model_type_to_module_name(key):
    """Converts a config key to the corresponding module."""
    # Special treatment
    if key in SPECIAL_MODEL_TYPE_TO_MODULE_NAME:
        return SPECIAL_MODEL_TYPE_TO_MODULE_NAME[key]

    return key.replace("-", "_")


def config_class_to_model_type(config):
    """Converts a config class name to the corresponding model type"""
    for key, cls in CONFIG_MAPPING_NAMES.items():
        if cls == config:
            return key
    return None
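
# For example (sketch): config_class_to_model_type("BertConfig") returns "bert",
# while a class name absent from CONFIG_MAPPING_NAMES falls through to None.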


class _LazyConfigMapping(OrderedDict):
    """
    A dictionary that lazily loads its values when they are requested.
    """

    def __init__(self, mapping):
        self._mapping = mapping
        self._modules = {}

    def __getitem__(self, key):
        if key not in self._mapping:
            raise KeyError(key)
        value = self._mapping[key]
        module_name = model_type_to_module_name(key)
        if module_name not in self._modules:
            self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
        return getattr(self._modules[module_name], value)

    def keys(self):
        return self._mapping.keys()

    def values(self):
        return [self[k] for k in self._mapping.keys()]

    def items(self):
        return [(k, self[k]) for k in self._mapping.keys()]

    def __iter__(self):
        return iter(self._mapping.keys())

    def __contains__(self, item):
        return item in self._mapping


CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
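# Usage sketch (illustrative): indexing triggers the import of the matching model
# module on first access, e.g. CONFIG_MAPPING["bert"] imports transformers.models.bert
# and returns its BertConfig class; no model module is imported until its key is
# actually requested.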


class _LazyLoadAllMappings(OrderedDict):
    """
    A mapping that will load all pairs of key values at the first access (either by indexing, requesting keys,
    values, etc.)

    Args:
        mapping: The mapping to load.
    """

    def __init__(self, mapping):
        self._mapping = mapping
        self._initialized = False
        self._data = {}

    def _initialize(self):
        if self._initialized:
            return
        warnings.warn(
            "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. "
            "It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.",
            FutureWarning,
        )

        for model_type, map_name in self._mapping.items():
            module_name = model_type_to_module_name(model_type)
            module = importlib.import_module(f".{module_name}", "transformers.models")
            mapping = getattr(module, map_name)
            self._data.update(mapping)

        self._initialized = True

    def __getitem__(self, key):
        self._initialize()
        return self._data[key]

    def keys(self):
        self._initialize()
        return self._data.keys()

    def values(self):
        self._initialize()
        return self._data.values()

    def items(self):
        self._initialize()
        return self._data.items()

    def __iter__(self):
        self._initialize()
        return iter(self._data)

    def __contains__(self, item):
        self._initialize()
        return item in self._data


ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = _LazyLoadAllMappings(CONFIG_ARCHIVE_MAP_MAPPING_NAMES)
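# Illustrative note: any access (indexing, iteration, containment checks) emits the
# FutureWarning above once and then eagerly imports every model module to merge its
# archive map, so e.g. "bert-base-uncased" in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP pays
# the full import cost on first use.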


def _get_class_name(model_class: Union[str, List[str]]):
    if isinstance(model_class, (list, tuple)):
        return " or ".join([f":class:`~transformers.{c}`" for c in model_class if c is not None])
    return f":class:`~transformers.{model_class}`"
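
# For example (illustrative): _get_class_name("BertModel") gives
# ":class:`~transformers.BertModel`", and _get_class_name(["BertModel", "TFBertModel"])
# gives ":class:`~transformers.BertModel` or :class:`~transformers.TFBertModel`".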


def _list_model_options(indent, config_to_class=None, use_model_types=True):
    if config_to_class is None and not use_model_types:
        raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
    if use_model_types:
        if config_to_class is None:
            model_type_to_name = {
                model_type: f":class:`~transformers.{config}`" for model_type, config in CONFIG_MAPPING_NAMES.items()
            }
        else:
            model_type_to_name = {
                model_type: _get_class_name(model_class)
                for model_type, model_class in config_to_class.items()
                if model_type in MODEL_NAMES_MAPPING
            }
        lines = [
            f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
            for model_type in sorted(model_type_to_name.keys())
        ]
    else:
        config_to_name = {
            CONFIG_MAPPING_NAMES[config]: _get_class_name(clas)
            for config, clas in config_to_class.items()
            if config in CONFIG_MAPPING_NAMES
        }
        config_to_model_name = {
            config: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING_NAMES.items()
        }
        lines = [
            f"{indent}- :class:`~transformers.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
            for config_name in sorted(config_to_name.keys())
        ]
    return "\n".join(lines)


def replace_list_option_in_docstrings(config_to_class=None, use_model_types=True):
    def docstring_decorator(fn):
        docstrings = fn.__doc__
        lines = docstrings.split("\n")
        i = 0
        while i < len(lines) and re.search(r"^(\s*)List options\s*$", lines[i]) is None:
            i += 1
        if i < len(lines):
            indent = re.search(r"^(\s*)List options\s*$", lines[i]).groups()[0]
            if use_model_types:
                indent = f"{indent}    "
            lines[i] = _list_model_options(indent, config_to_class=config_to_class, use_model_types=use_model_types)
            docstrings = "\n".join(lines)
        else:
            raise ValueError(
                f"The function {fn} should have an empty 'List options' in its docstring as placeholder, current docstring is:\n{docstrings}"
            )
        fn.__doc__ = docstrings
        return fn

    return docstring_decorator
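
# Sketch of the decorator in action (hypothetical function, shown for clarity):
#
#     @replace_list_option_in_docstrings()
#     def pick_config(cls):
#         """Pick a configuration class.
#
#         List options
#         """
#
# After decoration, the indented "List options" placeholder line in the docstring
# is replaced by a bullet list of all model types in CONFIG_MAPPING_NAMES.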


class AutoConfig:
    r"""
    This is a generic configuration class that will be instantiated as one of the configuration classes of the library
    when created with the :meth:`~transformers.AutoConfig.from_pretrained` class method.

    This class cannot be instantiated directly using ``__init__()`` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoConfig is designed to be instantiated "
            "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    def for_model(cls, model_type: str, *args, **kwargs):
        if model_type in CONFIG_MAPPING:
            config_class = CONFIG_MAPPING[model_type]
            return config_class(*args, **kwargs)
        raise ValueError(
            f"Unrecognized model identifier: {model_type}. Should contain one of {', '.join(CONFIG_MAPPING.keys())}"
        )
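
    # Usage sketch (illustrative): AutoConfig.for_model("bert", hidden_size=128)
    # builds a fresh BertConfig(hidden_size=128) from scratch, without downloading
    # anything; unspecified attributes keep their library defaults.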

    @classmethod
    @replace_list_option_in_docstrings()
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r"""
        Instantiate one of the configuration classes of the library from a pretrained model configuration.

        The configuration class to instantiate is selected based on the :obj:`model_type` property of the config object
        that is loaded, or when it's missing, by falling back to using pattern matching on
        :obj:`pretrained_model_name_or_path`:

        List options

        Args:
            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                Can be either:

                    - A string, the `model id` of a pretrained model configuration hosted inside a model repo on
                      huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
                      namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
                    - A path to a `directory` containing a configuration file saved using the
                      :meth:`~transformers.PretrainedConfig.save_pretrained` method, or the
                      :meth:`~transformers.PreTrainedModel.save_pretrained` method, e.g., ``./my_model_directory/``.
                    - A path or url to a saved configuration JSON `file`, e.g.,
                      ``./my_model_directory/configuration.json``.
            cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the model weights and configuration files and override the
                cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
                identifier allowed by git.
            return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
                If :obj:`False`, then this function returns just the final configuration object.

                If :obj:`True`, then this function returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs`
                is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e.,
                the part of ``kwargs`` which has not been used to update ``config`` and is otherwise ignored.
            kwargs(additional keyword arguments, `optional`):
                The values in kwargs of any keys which are configuration attributes will be used to override the loaded
                values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
                by the ``return_unused_kwargs`` keyword parameter.

        Examples::

            >>> from transformers import AutoConfig

            >>> # Download configuration from huggingface.co and cache.
            >>> config = AutoConfig.from_pretrained('bert-base-uncased')

            >>> # Download configuration from huggingface.co (user-uploaded) and cache.
            >>> config = AutoConfig.from_pretrained('dbmdz/bert-base-german-cased')

            >>> # If configuration file is in a directory (e.g., was saved using `save_pretrained('./test/saved_model/')`).
            >>> config = AutoConfig.from_pretrained('./test/bert_saved_model/')

            >>> # Load a specific configuration file.
            >>> config = AutoConfig.from_pretrained('./test/bert_saved_model/my_configuration.json')

            >>> # Change some config attributes when loading a pretrained config.
            >>> config = AutoConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False)
            >>> config.output_attentions
            True
            >>> config, unused_kwargs = AutoConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False, return_unused_kwargs=True)
            >>> config.output_attentions
            True
            >>> unused_kwargs
            {'foo': False}
        """
        kwargs["_from_auto"] = True
        config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
        if "model_type" in config_dict:
            config_class = CONFIG_MAPPING[config_dict["model_type"]]
            return config_class.from_dict(config_dict, **kwargs)
        else:
            # Fallback: use pattern matching on the string.
            for pattern, config_class in CONFIG_MAPPING.items():
                if pattern in str(pretrained_model_name_or_path):
                    return config_class.from_dict(config_dict, **kwargs)

        raise ValueError(
            f"Unrecognized model in {pretrained_model_name_or_path}. "
            f"Should have a `model_type` key in its {CONFIG_NAME}, or contain one of the following strings "
            f"in its name: {', '.join(CONFIG_MAPPING.keys())}"
        )