configuration_auto.py 27.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Sylvain Gugger's avatar
Sylvain Gugger committed
15
""" Auto Config class."""
16
import importlib
17
import re
18
import warnings
19
from collections import OrderedDict
20
from typing import List, Union
21

Sylvain Gugger's avatar
Sylvain Gugger committed
22
from ...configuration_utils import PretrainedConfig
23
from ...file_utils import CONFIG_NAME
24
25
from ...utils import logging
from .dynamic import get_class_from_dynamic_module
Aymeric Augustin's avatar
Aymeric Augustin committed
26

27

28
29
logger = logging.get_logger(__name__)

30
31
32
# Maps each model type to the *name* of its configuration class (the class itself is imported lazily).
# Order matters: when a config has no `model_type`, `AutoConfig.from_pretrained` substring-matches these keys
# against the checkpoint name in iteration order, so more specific keys (e.g. "xlm-roberta") must come before
# keys they contain (e.g. "xlm").
CONFIG_MAPPING_NAMES = OrderedDict(
    [
        # Add configs here
        ("yoso", "YosoConfig"),
        ("swin", "SwinConfig"),
        ("vilt", "ViltConfig"),
        ("vit_mae", "ViTMAEConfig"),
        ("realm", "RealmConfig"),
        ("nystromformer", "NystromformerConfig"),
        ("xglm", "XGLMConfig"),
        ("imagegpt", "ImageGPTConfig"),
        ("qdqbert", "QDQBertConfig"),
        ("vision-encoder-decoder", "VisionEncoderDecoderConfig"),
        ("trocr", "TrOCRConfig"),
        ("fnet", "FNetConfig"),
        ("segformer", "SegformerConfig"),
        ("vision-text-dual-encoder", "VisionTextDualEncoderConfig"),
        ("perceiver", "PerceiverConfig"),
        ("gptj", "GPTJConfig"),
        ("layoutlmv2", "LayoutLMv2Config"),
        ("beit", "BeitConfig"),
        ("rembert", "RemBertConfig"),
        ("visual_bert", "VisualBertConfig"),
        ("canine", "CanineConfig"),
        ("roformer", "RoFormerConfig"),
        ("clip", "CLIPConfig"),
        ("bigbird_pegasus", "BigBirdPegasusConfig"),
        ("deit", "DeiTConfig"),
        ("luke", "LukeConfig"),
        ("detr", "DetrConfig"),
        ("gpt_neo", "GPTNeoConfig"),
        ("big_bird", "BigBirdConfig"),
        ("speech_to_text_2", "Speech2Text2Config"),
        ("speech_to_text", "Speech2TextConfig"),
        ("vit", "ViTConfig"),
        ("wav2vec2", "Wav2Vec2Config"),
        ("m2m_100", "M2M100Config"),
        ("convbert", "ConvBertConfig"),
        ("led", "LEDConfig"),
        ("blenderbot-small", "BlenderbotSmallConfig"),
        ("retribert", "RetriBertConfig"),
        ("ibert", "IBertConfig"),
        ("mt5", "MT5Config"),
        ("t5", "T5Config"),
        ("mobilebert", "MobileBertConfig"),
        ("distilbert", "DistilBertConfig"),
        ("albert", "AlbertConfig"),
        ("bert-generation", "BertGenerationConfig"),
        ("camembert", "CamembertConfig"),
        ("xlm-roberta", "XLMRobertaConfig"),
        ("pegasus", "PegasusConfig"),
        ("marian", "MarianConfig"),
        ("mbart", "MBartConfig"),
        ("megatron-bert", "MegatronBertConfig"),
        ("mpnet", "MPNetConfig"),
        ("bart", "BartConfig"),
        ("blenderbot", "BlenderbotConfig"),
        ("reformer", "ReformerConfig"),
        ("longformer", "LongformerConfig"),
        ("roberta", "RobertaConfig"),
        ("deberta-v2", "DebertaV2Config"),
        ("deberta", "DebertaConfig"),
        ("flaubert", "FlaubertConfig"),
        ("fsmt", "FSMTConfig"),
        ("squeezebert", "SqueezeBertConfig"),
        ("hubert", "HubertConfig"),
        ("bert", "BertConfig"),
        ("openai-gpt", "OpenAIGPTConfig"),
        ("gpt2", "GPT2Config"),
        ("transfo-xl", "TransfoXLConfig"),
        ("xlnet", "XLNetConfig"),
        ("xlm-prophetnet", "XLMProphetNetConfig"),
        ("prophetnet", "ProphetNetConfig"),
        ("xlm", "XLMConfig"),
        ("ctrl", "CTRLConfig"),
        ("electra", "ElectraConfig"),
        ("speech-encoder-decoder", "SpeechEncoderDecoderConfig"),
        ("encoder-decoder", "EncoderDecoderConfig"),
        ("funnel", "FunnelConfig"),
        ("lxmert", "LxmertConfig"),
        ("dpr", "DPRConfig"),
        ("layoutlm", "LayoutLMConfig"),
        ("rag", "RagConfig"),
        ("tapas", "TapasConfig"),
        ("splinter", "SplinterConfig"),
        ("sew-d", "SEWDConfig"),
        ("sew", "SEWConfig"),
        ("unispeech-sat", "UniSpeechSatConfig"),
        ("unispeech", "UniSpeechConfig"),
        ("wavlm", "WavLMConfig"),
    ]
)
122

123
# Maps each model type to the *name* of the pretrained-config archive map exposed by that model's module.
# The names are resolved lazily (via `getattr` on the imported model module) by `_LazyLoadAllMappings`.
CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
    [
        # Add archive maps here
        ("yoso", "YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("swin", "SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("vilt", "VILT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("vit_mae", "VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("realm", "REALM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("nystromformer", "NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("xglm", "XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("imagegpt", "IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("qdqbert", "QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("fnet", "FNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("pegasus", "PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("segformer", "SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("perceiver", "PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("canine", "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("clip", "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("bigbird_pegasus", "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("detr", "DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("gpt_neo", "GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("big_bird", "BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("megatron-bert", "MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("speech_to_text", "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("speech_to_text_2", "SPEECH_TO_TEXT_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("vit", "VIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("wav2vec2", "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("convbert", "CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("blenderbot-small", "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("bert", "BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("bart", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("blenderbot", "BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("mbart", "MBART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("openai-gpt", "OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("ctrl", "CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("xlnet", "XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("xlm", "XLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("roberta", "ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("distilbert", "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("albert", "ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("camembert", "CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("t5", "T5_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("xlm-roberta", "XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("flaubert", "FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("fsmt", "FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("electra", "ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("retribert", "RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("funnel", "FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("dpr", "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("deberta", "DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("deberta-v2", "DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("squeezebert", "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("xlm-prophetnet", "XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("prophetnet", "PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("mpnet", "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("hubert", "HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("splinter", "SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("sew-d", "SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("sew", "SEW_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("unispeech-sat", "UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("unispeech", "UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP"),
    ]
)

# Maps each model type to its human-readable (cased) model name, used when rendering documentation lists.
# This table also carries names for model types that have no dedicated entry in `CONFIG_MAPPING_NAMES`
# (e.g. "barthez", "phobert").
MODEL_NAMES_MAPPING = OrderedDict(
    [
        # Add full (and cased) model names here
        ("yoso", "YOSO"),
        ("swin", "Swin"),
        ("vilt", "ViLT"),
        ("vit_mae", "ViTMAE"),
        ("realm", "Realm"),
        ("nystromformer", "Nystromformer"),
        ("xglm", "XGLM"),
        ("imagegpt", "ImageGPT"),
        ("qdqbert", "QDQBert"),
        ("vision-encoder-decoder", "Vision Encoder decoder"),
        ("trocr", "TrOCR"),
        ("fnet", "FNet"),
        ("segformer", "SegFormer"),
        ("vision-text-dual-encoder", "VisionTextDualEncoder"),
        ("perceiver", "Perceiver"),
        ("gptj", "GPT-J"),
        ("beit", "BEiT"),
        ("rembert", "RemBERT"),
        ("layoutlmv2", "LayoutLMv2"),
        ("visual_bert", "VisualBert"),
        ("canine", "Canine"),
        ("roformer", "RoFormer"),
        ("clip", "CLIP"),
        ("bigbird_pegasus", "BigBirdPegasus"),
        ("deit", "DeiT"),
        ("luke", "LUKE"),
        ("detr", "DETR"),
        ("gpt_neo", "GPT Neo"),
        ("big_bird", "BigBird"),
        ("speech_to_text_2", "Speech2Text2"),
        ("speech_to_text", "Speech2Text"),
        ("vit", "ViT"),
        ("wav2vec2", "Wav2Vec2"),
        ("m2m_100", "M2M100"),
        ("convbert", "ConvBERT"),
        ("led", "LED"),
        ("blenderbot-small", "BlenderbotSmall"),
        ("retribert", "RetriBERT"),
        ("ibert", "I-BERT"),
        ("t5", "T5"),
        ("mobilebert", "MobileBERT"),
        ("distilbert", "DistilBERT"),
        ("albert", "ALBERT"),
        ("bert-generation", "Bert Generation"),
        ("camembert", "CamemBERT"),
        ("xlm-roberta", "XLM-RoBERTa"),
        ("pegasus", "Pegasus"),
        ("blenderbot", "Blenderbot"),
        ("marian", "Marian"),
        ("mbart", "mBART"),
        ("megatron-bert", "MegatronBert"),
        ("bart", "BART"),
        ("reformer", "Reformer"),
        ("longformer", "Longformer"),
        ("roberta", "RoBERTa"),
        ("flaubert", "FlauBERT"),
        ("fsmt", "FairSeq Machine-Translation"),
        ("squeezebert", "SqueezeBERT"),
        ("bert", "BERT"),
        ("openai-gpt", "OpenAI GPT"),
        ("gpt2", "OpenAI GPT-2"),
        ("transfo-xl", "Transformer-XL"),
        ("xlnet", "XLNet"),
        ("xlm", "XLM"),
        ("ctrl", "CTRL"),
        ("electra", "ELECTRA"),
        ("encoder-decoder", "Encoder decoder"),
        ("speech-encoder-decoder", "Speech Encoder decoder"),
        # NOTE: "vision-encoder-decoder" used to be listed a second time here with the same value; the
        # redundant entry was dropped (the resulting mapping is identical).
        ("funnel", "Funnel Transformer"),
        ("lxmert", "LXMERT"),
        ("deberta-v2", "DeBERTa-v2"),
        ("deberta", "DeBERTa"),
        ("layoutlm", "LayoutLM"),
        ("dpr", "DPR"),
        ("rag", "RAG"),
        ("xlm-prophetnet", "XLMProphetNet"),
        ("prophetnet", "ProphetNet"),
        ("mt5", "mT5"),
        ("mpnet", "MPNet"),
        ("tapas", "TAPAS"),
        ("hubert", "Hubert"),
        ("barthez", "BARThez"),
        ("phobert", "PhoBERT"),
        ("bartpho", "BARTpho"),
        ("cpm", "CPM"),
        ("bertweet", "Bertweet"),
        ("bert-japanese", "BertJapanese"),
        ("byt5", "ByT5"),
        ("mbart50", "mBART-50"),
        ("splinter", "Splinter"),
        ("sew-d", "SEW-D"),
        ("sew", "SEW"),
        ("unispeech-sat", "UniSpeechSat"),
        ("unispeech", "UniSpeech"),
        ("wavlm", "WavLM"),
        ("bort", "BORT"),
        ("dialogpt", "DialoGPT"),
        ("xls_r", "XLS-R"),
        ("t5v1.1", "T5v1.1"),
        ("herbert", "HerBERT"),
        ("wav2vec2_phoneme", "Wav2Vec2Phoneme"),
        ("megatron_gpt2", "MegatronGPT2"),
        ("xlsr_wav2vec2", "XLSR-Wav2Vec2"),
        ("mluke", "mLUKE"),
        ("layoutxlm", "LayoutXLM"),
    ]
)

316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
# Model types whose module name cannot be derived by simply swapping dashes for underscores.
SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
    [
        ("openai-gpt", "openai"),
    ]
)


def model_type_to_module_name(key):
    """Return the name of the model module that corresponds to the model type `key`."""
    try:
        # Explicit overrides take precedence over the generic rule below.
        return SPECIAL_MODEL_TYPE_TO_MODULE_NAME[key]
    except KeyError:
        # Generic rule: module names use underscores where model types use dashes.
        return key.replace("-", "_")


def config_class_to_model_type(config):
    """Return the model type registered for the config class name `config`, or `None` if there is none."""
    # Reverse lookup in `CONFIG_MAPPING_NAMES`; the first matching entry (in mapping order) wins.
    matches = (model_type for model_type, class_name in CONFIG_MAPPING_NAMES.items() if class_name == config)
    return next(matches, None)


class _LazyConfigMapping(OrderedDict):
    """
    A dictionary that lazily load its values when they are requested.
    """

    def __init__(self, mapping):
        self._mapping = mapping
343
        self._extra_content = {}
344
345
346
        self._modules = {}

    def __getitem__(self, key):
347
348
        if key in self._extra_content:
            return self._extra_content[key]
349
350
351
352
353
354
355
356
357
        if key not in self._mapping:
            raise KeyError(key)
        value = self._mapping[key]
        module_name = model_type_to_module_name(key)
        if module_name not in self._modules:
            self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
        return getattr(self._modules[module_name], value)

    def keys(self):
358
        return list(self._mapping.keys()) + list(self._extra_content.keys())
359
360

    def values(self):
361
        return [self[k] for k in self._mapping.keys()] + list(self._extra_content.values())
362

363
    def items(self):
364
        return [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items())
365
366

    def __iter__(self):
367
        return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
368
369

    def __contains__(self, item):
370
371
372
373
374
375
376
377
378
        return item in self._mapping or item in self._extra_content

    def register(self, key, value):
        """
        Register a new configuration in this mapping.
        """
        if key in self._mapping.keys():
            raise ValueError(f"'{key}' is already used by a Transformers config, pick another name.")
        self._extra_content[key] = value
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443


# Module-level mapping from model type to config class. Each config class is imported from its model module on
# first lookup rather than when this file is imported.
CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)


class _LazyLoadAllMappings(OrderedDict):
    """
    A mapping that will load all pairs of key values at the first access (either by indexing, requestions keys, values,
    etc.)

    Args:
        mapping: The mapping to load.
    """

    def __init__(self, mapping):
        self._mapping = mapping
        self._initialized = False
        self._data = {}

    def _initialize(self):
        if self._initialized:
            return
        warnings.warn(
            "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. "
            "It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.",
            FutureWarning,
        )

        for model_type, map_name in self._mapping.items():
            module_name = model_type_to_module_name(model_type)
            module = importlib.import_module(f".{module_name}", "transformers.models")
            mapping = getattr(module, map_name)
            self._data.update(mapping)

        self._initialized = True

    def __getitem__(self, key):
        self._initialize()
        return self._data[key]

    def keys(self):
        self._initialize()
        return self._data.keys()

    def values(self):
        self._initialize()
        return self._data.values()

    def items(self):
        self._initialize()
        return self._data.keys()

    def __iter__(self):
        self._initialize()
        return iter(self._data)

    def __contains__(self, item):
        self._initialize()
        return item in self._data


# Deprecated aggregate of the per-model archive maps. Nothing is imported here; the first access emits a
# `FutureWarning` and imports every model module listed in `CONFIG_ARCHIVE_MAP_MAPPING_NAMES`.
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = _LazyLoadAllMappings(CONFIG_ARCHIVE_MAP_MAPPING_NAMES)


def _get_class_name(model_class: Union[str, List[str]]):
444
    if isinstance(model_class, (list, tuple)):
Stas Bekman's avatar
Stas Bekman committed
445
446
        return " or ".join([f"[`{c}`]" for c in model_class if c is not None])
    return f"[`{model_class}`]"
447
448


449
450
451
452
453
def _list_model_options(indent, config_to_class=None, use_model_types=True):
    if config_to_class is None and not use_model_types:
        raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
    if use_model_types:
        if config_to_class is None:
Stas Bekman's avatar
Stas Bekman committed
454
            model_type_to_name = {model_type: f"[`{config}`]" for model_type, config in CONFIG_MAPPING_NAMES.items()}
455
456
        else:
            model_type_to_name = {
457
458
459
                model_type: _get_class_name(model_class)
                for model_type, model_class in config_to_class.items()
                if model_type in MODEL_NAMES_MAPPING
460
461
            }
        lines = [
462
            f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
463
            for model_type in sorted(model_type_to_name.keys())
464
465
        ]
    else:
466
467
468
469
470
        config_to_name = {
            CONFIG_MAPPING_NAMES[config]: _get_class_name(clas)
            for config, clas in config_to_class.items()
            if config in CONFIG_MAPPING_NAMES
        }
471
        config_to_model_name = {
472
            config: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING_NAMES.items()
473
474
        }
        lines = [
Stas Bekman's avatar
Stas Bekman committed
475
            f"{indent}- [`{config_name}`] configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
476
            for config_name in sorted(config_to_name.keys())
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
        ]
    return "\n".join(lines)


def replace_list_option_in_docstrings(config_to_class=None, use_model_types=True):
    """
    Decorator factory that swaps the "List options" placeholder line of a docstring for the generated model list.

    The decorated function must contain a line made only of `List options` (plus surrounding whitespace) in its
    docstring; otherwise a `ValueError` is raised at decoration time. `config_to_class` and `use_model_types` are
    forwarded to `_list_model_options`.
    """

    def docstring_decorator(fn):
        original_doc = fn.__doc__
        doc_lines = original_doc.split("\n")

        # Locate the first placeholder line, remembering both its position and the regex match.
        candidates = ((position, re.search(r"^(\s*)List options\s*$", text)) for position, text in enumerate(doc_lines))
        found = next((pair for pair in candidates if pair[1] is not None), None)
        if found is None:
            raise ValueError(
                f"The function {fn} should have an empty 'List options' in its docstring as placeholder, current docstring is:\n{original_doc}"
            )

        position, placeholder = found
        indent = placeholder.groups()[0]
        if use_model_types:
            # Lists of model types are nested one level deeper than the placeholder.
            indent = f"{indent}    "
        doc_lines[position] = _list_model_options(
            indent, config_to_class=config_to_class, use_model_types=use_model_types
        )
        fn.__doc__ = "\n".join(doc_lines)
        return fn

    return docstring_decorator


Julien Chaumond's avatar
Julien Chaumond committed
504
class AutoConfig:
    r"""
    This is a generic configuration class that will be instantiated as one of the configuration classes of the library
    when created with the [`~AutoConfig.from_pretrained`] class method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoConfig is designed to be instantiated "
            "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    def for_model(cls, model_type: str, *args, **kwargs):
        """
        Instantiate the configuration class registered for `model_type`.

        Args:
            model_type (`str`): A key of `CONFIG_MAPPING`, like "bert".
            *args, **kwargs: Forwarded to the configuration class constructor.

        Raises:
            ValueError: If `model_type` is not a key of `CONFIG_MAPPING`.
        """
        if model_type in CONFIG_MAPPING:
            config_class = CONFIG_MAPPING[model_type]
            return config_class(*args, **kwargs)
        raise ValueError(
            f"Unrecognized model identifier: {model_type}. Should contain one of {', '.join(CONFIG_MAPPING.keys())}"
        )

    @classmethod
    @replace_list_option_in_docstrings()
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r"""
        Instantiate one of the configuration classes of the library from a pretrained model configuration.

        The configuration class to instantiate is selected based on the `model_type` property of the config object that
        is loaded, or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:

        List options

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *model id* of a pretrained model configuration hosted inside a model repo on
                      huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
                      namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
                    - A path to a *directory* containing a configuration file saved using the
                      [`~PretrainedConfig.save_pretrained`] method, or the [`~PreTrainedModel.save_pretrained`] method,
                      e.g., `./my_model_directory/`.
                    - A path or url to a saved configuration JSON *file*, e.g.,
                      `./my_model_directory/configuration.json`.
            cache_dir (`str` or `os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download the model weights and configuration files and override the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                If `False`, then this function returns just the final configuration object.

                If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
                dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
                part of `kwargs` which has not been used to update `config` and is otherwise ignored.
            trust_remote_code (`bool`, *optional*, defaults to `False`):
                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
                should only be set to `True` for repositories you trust and in which you have read the code, as it will
                execute code present on the Hub on your local machine.
            kwargs (additional keyword arguments, *optional*):
                The values in kwargs of any keys which are configuration attributes will be used to override the loaded
                values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
                by the `return_unused_kwargs` keyword parameter.

        Examples:

        ```python
        >>> from transformers import AutoConfig

        >>> # Download configuration from huggingface.co and cache.
        >>> config = AutoConfig.from_pretrained("bert-base-uncased")

        >>> # Download configuration from huggingface.co (user-uploaded) and cache.
        >>> config = AutoConfig.from_pretrained("dbmdz/bert-base-german-cased")

        >>> # If configuration file is in a directory (e.g., was saved using *save_pretrained('./test/saved_model/')*).
        >>> config = AutoConfig.from_pretrained("./test/bert_saved_model/")

        >>> # Load a specific configuration file.
        >>> config = AutoConfig.from_pretrained("./test/bert_saved_model/my_configuration.json")

        >>> # Change some config attributes when loading a pretrained config.
        >>> config = AutoConfig.from_pretrained("bert-base-uncased", output_attentions=True, foo=False)
        >>> config.output_attentions
        True

        >>> config, unused_kwargs = AutoConfig.from_pretrained(
        ...     "bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
        ... )
        >>> config.output_attentions
        True

        >>> unused_kwargs
        {'foo': False}
        ```"""
        kwargs["_from_auto"] = True
        kwargs["name_or_path"] = pretrained_model_name_or_path
        # Consumed here so that it is not forwarded to the config loading calls below.
        trust_remote_code = kwargs.pop("trust_remote_code", False)
        config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
        if "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]:
            # The repo ships its own configuration class: only load it with the explicit consent of the user.
            if not trust_remote_code:
                raise ValueError(
                    f"Loading {pretrained_model_name_or_path} requires you to execute the configuration file in that repo "
                    "on your local machine. Make sure you have read the code there to avoid malicious use, then set "
                    "the option `trust_remote_code=True` to remove this error."
                )
            if kwargs.get("revision", None) is None:
                logger.warning(
                    "Explicitly passing a `revision` is encouraged when loading a configuration with custom code to "
                    "ensure no malicious code has been contributed in a newer revision."
                )
            # `class_ref` has the form "<module file without extension>.<class name>".
            class_ref = config_dict["auto_map"]["AutoConfig"]
            module_file, class_name = class_ref.split(".")
            config_class = get_class_from_dynamic_module(
                pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
            )
            return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "model_type" in config_dict:
            config_class = CONFIG_MAPPING[config_dict["model_type"]]
            return config_class.from_dict(config_dict, **kwargs)
        else:
            # Fallback: use pattern matching on the string.
            # Only the keys are iterated so that just the matching config class gets imported, instead of
            # resolving (and thus importing) every entry of the lazy mapping.
            name_or_path = str(pretrained_model_name_or_path)
            for pattern in CONFIG_MAPPING.keys():
                if pattern in name_or_path:
                    return CONFIG_MAPPING[pattern].from_dict(config_dict, **kwargs)

        raise ValueError(
            f"Unrecognized model in {pretrained_model_name_or_path}. "
            f"Should have a `model_type` key in its {CONFIG_NAME}, or contain one of the following strings "
            f"in its name: {', '.join(CONFIG_MAPPING.keys())}"
        )

    @staticmethod
    def register(model_type, config):
        """
        Register a new configuration for this class.

        Args:
            model_type (`str`): The model type like "bert" or "gpt".
            config ([`PretrainedConfig`]): The config to register.

        Raises:
            ValueError: If `config` is a [`PretrainedConfig`] subclass whose `model_type` attribute differs from
                `model_type`, or if `model_type` is already used by a config of the library.
        """
        if issubclass(config, PretrainedConfig) and config.model_type != model_type:
            raise ValueError(
                "The config you are passing has a `model_type` attribute that is not consistent with the model type "
                f"you passed (config has {config.model_type} and you passed {model_type}). Fix one of those so they "
                "match!"
            )
        CONFIG_MAPPING.register(model_type, config)