configuration_auto.py 27.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Sylvain Gugger's avatar
Sylvain Gugger committed
15
""" Auto Config class."""
16
import importlib
17
import re
18
import warnings
19
from collections import OrderedDict
20
from typing import List, Union
21

Sylvain Gugger's avatar
Sylvain Gugger committed
22
from ...configuration_utils import PretrainedConfig
23
from ...file_utils import CONFIG_NAME
24
25
from ...utils import logging
from .dynamic import get_class_from_dynamic_module
Aymeric Augustin's avatar
Aymeric Augustin committed
26

27

28
29
logger = logging.get_logger(__name__)

30
31
32
# Maps each model type (the `model_type` value read from a config) to the name of its configuration class.
# The classes are not imported here; `_LazyConfigMapping` imports them on demand (see `CONFIG_MAPPING`).
CONFIG_MAPPING_NAMES = OrderedDict(
    [
        # Add configs here
        ("vit_mae", "ViTMAEConfig"),
        ("realm", "RealmConfig"),
        ("nystromformer", "NystromformerConfig"),
        ("imagegpt", "ImageGPTConfig"),
        ("qdqbert", "QDQBertConfig"),
        ("vision-encoder-decoder", "VisionEncoderDecoderConfig"),
        ("trocr", "TrOCRConfig"),
        ("fnet", "FNetConfig"),
        ("segformer", "SegformerConfig"),
        ("vision-text-dual-encoder", "VisionTextDualEncoderConfig"),
        ("perceiver", "PerceiverConfig"),
        ("gptj", "GPTJConfig"),
        ("layoutlmv2", "LayoutLMv2Config"),
        ("beit", "BeitConfig"),
        ("rembert", "RemBertConfig"),
        ("visual_bert", "VisualBertConfig"),
        ("canine", "CanineConfig"),
        ("roformer", "RoFormerConfig"),
        ("clip", "CLIPConfig"),
        ("bigbird_pegasus", "BigBirdPegasusConfig"),
        ("deit", "DeiTConfig"),
        ("luke", "LukeConfig"),
        ("detr", "DetrConfig"),
        ("gpt_neo", "GPTNeoConfig"),
        ("big_bird", "BigBirdConfig"),
        ("speech_to_text_2", "Speech2Text2Config"),
        ("speech_to_text", "Speech2TextConfig"),
        ("vit", "ViTConfig"),
        ("wav2vec2", "Wav2Vec2Config"),
        ("m2m_100", "M2M100Config"),
        ("convbert", "ConvBertConfig"),
        ("led", "LEDConfig"),
        ("blenderbot-small", "BlenderbotSmallConfig"),
        ("retribert", "RetriBertConfig"),
        ("ibert", "IBertConfig"),
        ("mt5", "MT5Config"),
        ("t5", "T5Config"),
        ("mobilebert", "MobileBertConfig"),
        ("distilbert", "DistilBertConfig"),
        ("albert", "AlbertConfig"),
        ("bert-generation", "BertGenerationConfig"),
        ("camembert", "CamembertConfig"),
        ("xlm-roberta", "XLMRobertaConfig"),
        ("pegasus", "PegasusConfig"),
        ("marian", "MarianConfig"),
        ("mbart", "MBartConfig"),
        ("megatron-bert", "MegatronBertConfig"),
        ("mpnet", "MPNetConfig"),
        ("bart", "BartConfig"),
        ("blenderbot", "BlenderbotConfig"),
        ("reformer", "ReformerConfig"),
        ("longformer", "LongformerConfig"),
        ("roberta", "RobertaConfig"),
        ("deberta-v2", "DebertaV2Config"),
        ("deberta", "DebertaConfig"),
        ("flaubert", "FlaubertConfig"),
        ("fsmt", "FSMTConfig"),
        ("squeezebert", "SqueezeBertConfig"),
        ("hubert", "HubertConfig"),
        ("bert", "BertConfig"),
        ("openai-gpt", "OpenAIGPTConfig"),
        ("gpt2", "GPT2Config"),
        ("transfo-xl", "TransfoXLConfig"),
        ("xlnet", "XLNetConfig"),
        ("xlm-prophetnet", "XLMProphetNetConfig"),
        ("prophetnet", "ProphetNetConfig"),
        ("xlm", "XLMConfig"),
        ("ctrl", "CTRLConfig"),
        ("electra", "ElectraConfig"),
        ("speech-encoder-decoder", "SpeechEncoderDecoderConfig"),
        ("encoder-decoder", "EncoderDecoderConfig"),
        ("funnel", "FunnelConfig"),
        ("lxmert", "LxmertConfig"),
        ("dpr", "DPRConfig"),
        ("layoutlm", "LayoutLMConfig"),
        ("rag", "RagConfig"),
        ("tapas", "TapasConfig"),
        ("splinter", "SplinterConfig"),
        ("sew-d", "SEWDConfig"),
        ("sew", "SEWConfig"),
        ("unispeech-sat", "UniSpeechSatConfig"),
        ("unispeech", "UniSpeechConfig"),
        ("wavlm", "WavLMConfig"),
    ]
)
118

119
# Maps each model type to the name of the `*_PRETRAINED_CONFIG_ARCHIVE_MAP` attribute exposed by its
# `transformers.models` submodule. In this file it feeds the deprecated `ALL_PRETRAINED_CONFIG_ARCHIVE_MAP`.
CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
    [
        # Add archive maps here
        ("vit_mae", "VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("realm", "REALM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("nystromformer", "NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("imagegpt", "IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("qdqbert", "QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("fnet", "FNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("pegasus", "PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("segformer", "SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("perceiver", "PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("canine", "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("clip", "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("bigbird_pegasus", "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("detr", "DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("gpt_neo", "GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("big_bird", "BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("megatron-bert", "MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("speech_to_text", "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("speech_to_text_2", "SPEECH_TO_TEXT_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("vit", "VIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("wav2vec2", "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("convbert", "CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("blenderbot-small", "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("bert", "BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("bart", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("blenderbot", "BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("mbart", "MBART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("openai-gpt", "OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("ctrl", "CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("xlnet", "XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("xlm", "XLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("roberta", "ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("distilbert", "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("albert", "ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("camembert", "CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("t5", "T5_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("xlm-roberta", "XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("flaubert", "FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("fsmt", "FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("electra", "ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("retribert", "RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("funnel", "FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("dpr", "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("deberta", "DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("deberta-v2", "DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("squeezebert", "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("xlm-prophetnet", "XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("prophetnet", "PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("mpnet", "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("hubert", "HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("splinter", "SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("sew-d", "SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("sew", "SEW_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("unispeech-sat", "UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("unispeech", "UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP"),
    ]
)

# Maps each model type to the human-readable model name used in generated documentation. Unlike
# CONFIG_MAPPING_NAMES it also lists names without a config class of their own (e.g. "barthez", "bort").
# Each key must appear only once: a repeated key silently overrides the earlier value in the OrderedDict.
MODEL_NAMES_MAPPING = OrderedDict(
    [
        # Add full (and cased) model names here
        ("vit_mae", "ViTMAE"),
        ("realm", "Realm"),
        ("nystromformer", "Nystromformer"),
        ("imagegpt", "ImageGPT"),
        ("qdqbert", "QDQBert"),
        ("vision-encoder-decoder", "Vision Encoder decoder"),
        ("trocr", "TrOCR"),
        ("fnet", "FNet"),
        ("segformer", "SegFormer"),
        ("vision-text-dual-encoder", "VisionTextDualEncoder"),
        ("perceiver", "Perceiver"),
        ("gptj", "GPT-J"),
        ("beit", "BEiT"),
        ("rembert", "RemBERT"),
        ("layoutlmv2", "LayoutLMv2"),
        ("visual_bert", "VisualBert"),
        ("canine", "Canine"),
        ("roformer", "RoFormer"),
        ("clip", "CLIP"),
        ("bigbird_pegasus", "BigBirdPegasus"),
        ("deit", "DeiT"),
        ("luke", "LUKE"),
        ("detr", "DETR"),
        ("gpt_neo", "GPT Neo"),
        ("big_bird", "BigBird"),
        ("speech_to_text_2", "Speech2Text2"),
        ("speech_to_text", "Speech2Text"),
        ("vit", "ViT"),
        ("wav2vec2", "Wav2Vec2"),
        ("m2m_100", "M2M100"),
        ("convbert", "ConvBERT"),
        ("led", "LED"),
        ("blenderbot-small", "BlenderbotSmall"),
        ("retribert", "RetriBERT"),
        ("ibert", "I-BERT"),
        ("t5", "T5"),
        ("mobilebert", "MobileBERT"),
        ("distilbert", "DistilBERT"),
        ("albert", "ALBERT"),
        ("bert-generation", "Bert Generation"),
        ("camembert", "CamemBERT"),
        ("xlm-roberta", "XLM-RoBERTa"),
        ("pegasus", "Pegasus"),
        ("blenderbot", "Blenderbot"),
        ("marian", "Marian"),
        ("mbart", "mBART"),
        ("megatron-bert", "MegatronBert"),
        ("bart", "BART"),
        ("reformer", "Reformer"),
        ("longformer", "Longformer"),
        ("roberta", "RoBERTa"),
        ("flaubert", "FlauBERT"),
        ("fsmt", "FairSeq Machine-Translation"),
        ("squeezebert", "SqueezeBERT"),
        ("bert", "BERT"),
        ("openai-gpt", "OpenAI GPT"),
        ("gpt2", "OpenAI GPT-2"),
        ("transfo-xl", "Transformer-XL"),
        ("xlnet", "XLNet"),
        ("xlm", "XLM"),
        ("ctrl", "CTRL"),
        ("electra", "ELECTRA"),
        ("encoder-decoder", "Encoder decoder"),
        ("speech-encoder-decoder", "Speech Encoder decoder"),
        ("funnel", "Funnel Transformer"),
        ("lxmert", "LXMERT"),
        ("deberta-v2", "DeBERTa-v2"),
        ("deberta", "DeBERTa"),
        ("layoutlm", "LayoutLM"),
        ("dpr", "DPR"),
        ("rag", "RAG"),
        ("xlm-prophetnet", "XLMProphetNet"),
        ("prophetnet", "ProphetNet"),
        ("mt5", "mT5"),
        ("mpnet", "MPNet"),
        ("tapas", "TAPAS"),
        ("hubert", "Hubert"),
        ("barthez", "BARThez"),
        ("phobert", "PhoBERT"),
        ("bartpho", "BARTpho"),
        ("cpm", "CPM"),
        ("bertweet", "Bertweet"),
        ("bert-japanese", "BertJapanese"),
        ("byt5", "ByT5"),
        ("mbart50", "mBART-50"),
        ("splinter", "Splinter"),
        ("sew-d", "SEW-D"),
        ("sew", "SEW"),
        ("unispeech-sat", "UniSpeechSat"),
        ("unispeech", "UniSpeech"),
        ("wavlm", "WavLM"),
        ("bort", "BORT"),
        ("dialogpt", "DialoGPT"),
        ("xls_r", "XLS-R"),
        ("t5v1.1", "T5v1.1"),
        ("herbert", "HerBERT"),
        ("wav2vec2_phoneme", "Wav2Vec2Phoneme"),
        ("megatron_gpt2", "MegatronGPT2"),
        ("xlsr_wav2vec2", "XLSR-Wav2Vec2"),
        ("mluke", "mLUKE"),
        ("layoutxlm", "LayoutXLM"),
    ]
)

304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
# Model types whose submodule name is not simply the model type with "-" replaced by "_".
SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict({"openai-gpt": "openai"})


def model_type_to_module_name(key):
    """
    Return the name of the `transformers.models` submodule for the model type `key`.

    Model types listed in `SPECIAL_MODEL_TYPE_TO_MODULE_NAME` use the name stored there; any other model type becomes
    its module name by turning every "-" into "_".
    """
    if key not in SPECIAL_MODEL_TYPE_TO_MODULE_NAME:
        return key.replace("-", "_")
    return SPECIAL_MODEL_TYPE_TO_MODULE_NAME[key]


def config_class_to_model_type(config):
    """
    Return the first model type whose configuration class name in `CONFIG_MAPPING_NAMES` equals `config`.

    Returns `None` when no model type uses that configuration class name.
    """
    matching_types = (
        model_type for model_type, class_name in CONFIG_MAPPING_NAMES.items() if class_name == config
    )
    return next(matching_types, None)


class _LazyConfigMapping(OrderedDict):
    """
    A dictionary that lazily load its values when they are requested.
    """

    def __init__(self, mapping):
        self._mapping = mapping
331
        self._extra_content = {}
332
333
334
        self._modules = {}

    def __getitem__(self, key):
335
336
        if key in self._extra_content:
            return self._extra_content[key]
337
338
339
340
341
342
343
344
345
        if key not in self._mapping:
            raise KeyError(key)
        value = self._mapping[key]
        module_name = model_type_to_module_name(key)
        if module_name not in self._modules:
            self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
        return getattr(self._modules[module_name], value)

    def keys(self):
346
        return list(self._mapping.keys()) + list(self._extra_content.keys())
347
348

    def values(self):
349
        return [self[k] for k in self._mapping.keys()] + list(self._extra_content.values())
350

351
    def items(self):
352
        return [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items())
353
354

    def __iter__(self):
355
        return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
356
357

    def __contains__(self, item):
358
359
360
361
362
363
364
365
366
        return item in self._mapping or item in self._extra_content

    def register(self, key, value):
        """
        Register a new configuration in this mapping.
        """
        if key in self._mapping.keys():
            raise ValueError(f"'{key}' is already used by a Transformers config, pick another name.")
        self._extra_content[key] = value
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431


# Lazy model type -> config class mapping built from CONFIG_MAPPING_NAMES. `AutoConfig.register` adds entries to it
# at runtime.
CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)


class _LazyLoadAllMappings(OrderedDict):
    """
    A mapping that will load all pairs of key values at the first access (either by indexing, requestions keys, values,
    etc.)

    Args:
        mapping: The mapping to load.
    """

    def __init__(self, mapping):
        self._mapping = mapping
        self._initialized = False
        self._data = {}

    def _initialize(self):
        if self._initialized:
            return
        warnings.warn(
            "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. "
            "It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.",
            FutureWarning,
        )

        for model_type, map_name in self._mapping.items():
            module_name = model_type_to_module_name(model_type)
            module = importlib.import_module(f".{module_name}", "transformers.models")
            mapping = getattr(module, map_name)
            self._data.update(mapping)

        self._initialized = True

    def __getitem__(self, key):
        self._initialize()
        return self._data[key]

    def keys(self):
        self._initialize()
        return self._data.keys()

    def values(self):
        self._initialize()
        return self._data.values()

    def items(self):
        self._initialize()
        return self._data.keys()

    def __iter__(self):
        self._initialize()
        return iter(self._data)

    def __contains__(self, item):
        self._initialize()
        return item in self._data


# Deprecated merged view of every archive map named in CONFIG_ARCHIVE_MAP_MAPPING_NAMES. It is only loaded (and the
# deprecation warning emitted) on first access.
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = _LazyLoadAllMappings(CONFIG_ARCHIVE_MAP_MAPPING_NAMES)


def _get_class_name(model_class: Union[str, List[str]]):
432
    if isinstance(model_class, (list, tuple)):
Stas Bekman's avatar
Stas Bekman committed
433
434
        return " or ".join([f"[`{c}`]" for c in model_class if c is not None])
    return f"[`{model_class}`]"
435
436


437
438
439
440
441
def _list_model_options(indent, config_to_class=None, use_model_types=True):
    if config_to_class is None and not use_model_types:
        raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
    if use_model_types:
        if config_to_class is None:
Stas Bekman's avatar
Stas Bekman committed
442
            model_type_to_name = {model_type: f"[`{config}`]" for model_type, config in CONFIG_MAPPING_NAMES.items()}
443
444
        else:
            model_type_to_name = {
445
446
447
                model_type: _get_class_name(model_class)
                for model_type, model_class in config_to_class.items()
                if model_type in MODEL_NAMES_MAPPING
448
449
            }
        lines = [
450
            f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
451
            for model_type in sorted(model_type_to_name.keys())
452
453
        ]
    else:
454
455
456
457
458
        config_to_name = {
            CONFIG_MAPPING_NAMES[config]: _get_class_name(clas)
            for config, clas in config_to_class.items()
            if config in CONFIG_MAPPING_NAMES
        }
459
        config_to_model_name = {
460
            config: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING_NAMES.items()
461
462
        }
        lines = [
Stas Bekman's avatar
Stas Bekman committed
463
            f"{indent}- [`{config_name}`] configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
464
            for config_name in sorted(config_to_name.keys())
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
        ]
    return "\n".join(lines)


def replace_list_option_in_docstrings(config_to_class=None, use_model_types=True):
    """
    Build a decorator that replaces the "List options" placeholder line of a function's docstring with the bullet list
    produced by `_list_model_options(indent, config_to_class=..., use_model_types=...)`.

    The placeholder's leading whitespace is reused as the bullet indent, with four extra spaces when
    `use_model_types` is true.

    Raises (from the decorator):
        `ValueError`: If the decorated function has a docstring that contains no "List options" line.
    """
    placeholder = re.compile(r"^(\s*)List options\s*$")

    def docstring_decorator(fn):
        docstrings = fn.__doc__
        if docstrings is None:
            # Docstrings were stripped (e.g. running under `python -OO`): there is nothing to rewrite, and failing here
            # would break importing the module.
            return fn
        lines = docstrings.split("\n")
        for i, line in enumerate(lines):
            match = placeholder.search(line)
            if match is not None:
                break
        else:
            raise ValueError(
                f"The function {fn} should have an empty 'List options' in its docstring as placeholder, current docstring is:\n{docstrings}"
            )
        indent = match.groups()[0]
        if use_model_types:
            indent = f"{indent}    "
        lines[i] = _list_model_options(indent, config_to_class=config_to_class, use_model_types=use_model_types)
        fn.__doc__ = "\n".join(lines)
        return fn

    return docstring_decorator


Julien Chaumond's avatar
Julien Chaumond committed
492
class AutoConfig:
    r"""
    This is a generic configuration class that will be instantiated as one of the configuration classes of the library
    when created with the [`~AutoConfig.from_pretrained`] class method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoConfig is designed to be instantiated "
            "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    def for_model(cls, model_type: str, *args, **kwargs):
        """
        Instantiate the configuration class registered for `model_type`, passing it `*args` and `**kwargs`.

        Raises:
            `ValueError`: If `model_type` is not a key of `CONFIG_MAPPING`.
        """
        if model_type in CONFIG_MAPPING:
            config_class = CONFIG_MAPPING[model_type]
            return config_class(*args, **kwargs)
        raise ValueError(
            f"Unrecognized model identifier: {model_type}. Should contain one of {', '.join(CONFIG_MAPPING.keys())}"
        )

    @classmethod
    @replace_list_option_in_docstrings()
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r"""
        Instantiate one of the configuration classes of the library from a pretrained model configuration.

        The configuration class to instantiate is selected based on the `model_type` property of the config object that
        is loaded, or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:

        List options

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *model id* of a pretrained model configuration hosted inside a model repo on
                      huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
                      namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
                    - A path to a *directory* containing a configuration file saved using the
                      [`~PretrainedConfig.save_pretrained`] method, or the [`~PreTrainedModel.save_pretrained`] method,
                      e.g., `./my_model_directory/`.
                    - A path or url to a saved configuration JSON *file*, e.g.,
                      `./my_model_directory/configuration.json`.
            cache_dir (`str` or `os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files and override the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            revision(`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                If `False`, then this function returns just the final configuration object.

                If `True`, then this function returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
                dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
                part of `kwargs` which has not been used to update `config` and is otherwise ignored.
            trust_remote_code (`bool`, *optional*, defaults to `False`):
                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
                should only be set to `True` for repositories you trust and in which you have read the code, as it will
                execute code present on the Hub on your local machine.
            kwargs(additional keyword arguments, *optional*):
                The values in kwargs of any keys which are configuration attributes will be used to override the loaded
                values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
                by the `return_unused_kwargs` keyword parameter.

        Raises:
            `ValueError`: If the configuration requires custom code and `trust_remote_code` is not `True`, or if no
                configuration class can be determined from `model_type` or from the name.
            `KeyError`: If the loaded configuration has a `model_type` that is not registered in `CONFIG_MAPPING`.

        Examples:

        ```python
        >>> from transformers import AutoConfig

        >>> # Download configuration from huggingface.co and cache.
        >>> config = AutoConfig.from_pretrained("bert-base-uncased")

        >>> # Download configuration from huggingface.co (user-uploaded) and cache.
        >>> config = AutoConfig.from_pretrained("dbmdz/bert-base-german-cased")

        >>> # If configuration file is in a directory (e.g., was saved using *save_pretrained('./test/saved_model/')*).
        >>> config = AutoConfig.from_pretrained("./test/bert_saved_model/")

        >>> # Load a specific configuration file.
        >>> config = AutoConfig.from_pretrained("./test/bert_saved_model/my_configuration.json")

        >>> # Change some config attributes when loading a pretrained config.
        >>> config = AutoConfig.from_pretrained("bert-base-uncased", output_attentions=True, foo=False)
        >>> config.output_attentions
        True

        >>> config, unused_kwargs = AutoConfig.from_pretrained(
        ...     "bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
        ... )
        >>> config.output_attentions
        True

        >>> unused_kwargs
        {'foo': False}
        ```"""
        kwargs["_from_auto"] = True
        kwargs["name_or_path"] = pretrained_model_name_or_path
        trust_remote_code = kwargs.pop("trust_remote_code", False)
        config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
        if "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]:
            # The config class lives in the repo itself: running it means executing remote code.
            if not trust_remote_code:
                raise ValueError(
                    f"Loading {pretrained_model_name_or_path} requires you to execute the configuration file in that repo "
                    "on your local machine. Make sure you have read the code there to avoid malicious use, then set "
                    "the option `trust_remote_code=True` to remove this error."
                )
            if kwargs.get("revision", None) is None:
                logger.warning(
                    "Explicitly passing a `revision` is encouraged when loading a configuration with custom code to "
                    "ensure no malicious code has been contributed in a newer revision."
                )
            class_ref = config_dict["auto_map"]["AutoConfig"]
            module_file, class_name = class_ref.split(".")
            config_class = get_class_from_dynamic_module(
                pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
            )
            return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "model_type" in config_dict:
            config_class = CONFIG_MAPPING[config_dict["model_type"]]
            return config_class.from_dict(config_dict, **kwargs)
        else:
            # Fallback: use pattern matching on the string.
            # Longest model types are tried first so that one containing a shorter one (e.g. "deberta-v2" vs "deberta",
            # or a registered "my-bert" vs "bert") is not shadowed. Only the matched config class is imported, instead
            # of every config module as `CONFIG_MAPPING.items()` would.
            for pattern in sorted(CONFIG_MAPPING.keys(), key=len, reverse=True):
                if pattern in str(pretrained_model_name_or_path):
                    return CONFIG_MAPPING[pattern].from_dict(config_dict, **kwargs)

        raise ValueError(
            f"Unrecognized model in {pretrained_model_name_or_path}. "
            f"Should have a `model_type` key in its {CONFIG_NAME}, or contain one of the following strings "
            f"in its name: {', '.join(CONFIG_MAPPING.keys())}"
        )

    @staticmethod
    def register(model_type, config):
        """
        Register a new configuration for this class.

        Args:
            model_type (`str`): The model type like "bert" or "gpt".
            config ([`PretrainedConfig`]): The config to register.

        Raises:
            `ValueError`: If `config` is a `PretrainedConfig` subclass whose `model_type` differs from `model_type`, or
                if `model_type` is already used by a built-in configuration.
        """
        if issubclass(config, PretrainedConfig) and config.model_type != model_type:
            raise ValueError(
                "The config you are passing has a `model_type` attribute that is not consistent with the model type "
                f"you passed (config has {config.model_type} and you passed {model_type}). Fix one of those so they "
                "match!"
            )
        CONFIG_MAPPING.register(model_type, config)