configuration_auto.py 28.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Sylvain Gugger's avatar
Sylvain Gugger committed
15
""" Auto Config class."""
16
import importlib
17
import re
18
import warnings
19
from collections import OrderedDict
20
from typing import List, Union
21

Sylvain Gugger's avatar
Sylvain Gugger committed
22
from ...configuration_utils import PretrainedConfig
23
from ...dynamic_module_utils import get_class_from_dynamic_module
24
from ...file_utils import CONFIG_NAME
25
from ...utils import logging


# Module-level logger obtained from the library's own logging utilities.
logger = logging.get_logger(__name__)


# Model type -> name of its configuration class. The class object itself is only imported on first access through
# `_LazyConfigMapping` (see `CONFIG_MAPPING`).
# NOTE(review): ordering looks significant. More specific keys such as "xlm-roberta-xl", "deberta-v2", "sew-d" and
# "unispeech-sat" are listed before the keys they contain, presumably for the name-pattern fallback mentioned in
# `AutoConfig.from_pretrained`. Preserve this ordering when adding entries.
CONFIG_MAPPING_NAMES = OrderedDict(
    [
        # Add configs here
        ("poolformer", "PoolFormerConfig"),
        ("convnext", "ConvNextConfig"),
        ("yoso", "YosoConfig"),
        ("swin", "SwinConfig"),
        ("vilt", "ViltConfig"),
        ("vit_mae", "ViTMAEConfig"),
        ("realm", "RealmConfig"),
        ("nystromformer", "NystromformerConfig"),
        ("xglm", "XGLMConfig"),
        ("imagegpt", "ImageGPTConfig"),
        ("qdqbert", "QDQBertConfig"),
        ("vision-encoder-decoder", "VisionEncoderDecoderConfig"),
        ("trocr", "TrOCRConfig"),
        ("fnet", "FNetConfig"),
        ("segformer", "SegformerConfig"),
        ("vision-text-dual-encoder", "VisionTextDualEncoderConfig"),
        ("perceiver", "PerceiverConfig"),
        ("gptj", "GPTJConfig"),
        ("layoutlmv2", "LayoutLMv2Config"),
        ("plbart", "PLBartConfig"),
        ("beit", "BeitConfig"),
        ("rembert", "RemBertConfig"),
        ("visual_bert", "VisualBertConfig"),
        ("canine", "CanineConfig"),
        ("roformer", "RoFormerConfig"),
        ("clip", "CLIPConfig"),
        ("bigbird_pegasus", "BigBirdPegasusConfig"),
        ("deit", "DeiTConfig"),
        ("luke", "LukeConfig"),
        ("detr", "DetrConfig"),
        ("gpt_neo", "GPTNeoConfig"),
        ("big_bird", "BigBirdConfig"),
        ("speech_to_text_2", "Speech2Text2Config"),
        ("speech_to_text", "Speech2TextConfig"),
        ("vit", "ViTConfig"),
        ("wav2vec2", "Wav2Vec2Config"),
        ("m2m_100", "M2M100Config"),
        ("convbert", "ConvBertConfig"),
        ("led", "LEDConfig"),
        ("blenderbot-small", "BlenderbotSmallConfig"),
        ("retribert", "RetriBertConfig"),
        ("ibert", "IBertConfig"),
        ("mt5", "MT5Config"),
        ("t5", "T5Config"),
        ("mobilebert", "MobileBertConfig"),
        ("distilbert", "DistilBertConfig"),
        ("albert", "AlbertConfig"),
        ("bert-generation", "BertGenerationConfig"),
        ("camembert", "CamembertConfig"),
        ("xlm-roberta-xl", "XLMRobertaXLConfig"),
        ("xlm-roberta", "XLMRobertaConfig"),
        ("pegasus", "PegasusConfig"),
        ("marian", "MarianConfig"),
        ("mbart", "MBartConfig"),
        ("megatron-bert", "MegatronBertConfig"),
        ("mpnet", "MPNetConfig"),
        ("bart", "BartConfig"),
        ("blenderbot", "BlenderbotConfig"),
        ("reformer", "ReformerConfig"),
        ("longformer", "LongformerConfig"),
        ("roberta", "RobertaConfig"),
        ("deberta-v2", "DebertaV2Config"),
        ("deberta", "DebertaConfig"),
        ("flaubert", "FlaubertConfig"),
        ("fsmt", "FSMTConfig"),
        ("squeezebert", "SqueezeBertConfig"),
        ("hubert", "HubertConfig"),
        ("bert", "BertConfig"),
        ("openai-gpt", "OpenAIGPTConfig"),
        ("gpt2", "GPT2Config"),
        ("transfo-xl", "TransfoXLConfig"),
        ("xlnet", "XLNetConfig"),
        ("xlm-prophetnet", "XLMProphetNetConfig"),
        ("prophetnet", "ProphetNetConfig"),
        ("xlm", "XLMConfig"),
        ("ctrl", "CTRLConfig"),
        ("electra", "ElectraConfig"),
        ("speech-encoder-decoder", "SpeechEncoderDecoderConfig"),
        ("encoder-decoder", "EncoderDecoderConfig"),
        ("funnel", "FunnelConfig"),
        ("lxmert", "LxmertConfig"),
        ("dpr", "DPRConfig"),
        ("layoutlm", "LayoutLMConfig"),
        ("rag", "RagConfig"),
        ("tapas", "TapasConfig"),
        ("splinter", "SplinterConfig"),
        ("sew-d", "SEWDConfig"),
        ("sew", "SEWConfig"),
        ("unispeech-sat", "UniSpeechSatConfig"),
        ("unispeech", "UniSpeechConfig"),
        ("wavlm", "WavLMConfig"),
    ]
)

# Model type -> name of the `*_PRETRAINED_CONFIG_ARCHIVE_MAP` attribute defined in that model's module under
# `transformers.models`. The maps are only imported and merged when `ALL_PRETRAINED_CONFIG_ARCHIVE_MAP` is first used.
CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
    [
        # Add archive maps here
        ("poolformer", "POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("convnext", "CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("yoso", "YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("swin", "SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("vilt", "VILT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("vit_mae", "VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("realm", "REALM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("nystromformer", "NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("xglm", "XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("imagegpt", "IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("qdqbert", "QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("fnet", "FNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("pegasus", "PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("segformer", "SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("perceiver", "PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("plbart", "PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("canine", "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("clip", "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("bigbird_pegasus", "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("detr", "DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("gpt_neo", "GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("big_bird", "BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("megatron-bert", "MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("speech_to_text", "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("speech_to_text_2", "SPEECH_TO_TEXT_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("vit", "VIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("wav2vec2", "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("convbert", "CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("blenderbot-small", "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("bert", "BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("bart", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("blenderbot", "BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("mbart", "MBART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("openai-gpt", "OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("ctrl", "CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("xlnet", "XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("xlm", "XLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("roberta", "ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("distilbert", "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("albert", "ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("camembert", "CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("t5", "T5_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("xlm-roberta", "XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("flaubert", "FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("fsmt", "FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("electra", "ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("retribert", "RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("funnel", "FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("dpr", "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("deberta", "DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("deberta-v2", "DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("squeezebert", "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("xlm-prophetnet", "XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("prophetnet", "PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("mpnet", "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("hubert", "HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("splinter", "SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("sew-d", "SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("sew", "SEW_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("unispeech-sat", "UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("unispeech", "UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP"),
    ]
)

# Model type -> human-readable (cased) model name, used when rendering the model lists in auto-class docstrings.
# It also names model types that have no entry in `CONFIG_MAPPING_NAMES` (e.g. "barthez", "phobert", "bort").
# Each key appears only once: a repeated key in the list would be silently merged by `OrderedDict`.
MODEL_NAMES_MAPPING = OrderedDict(
    [
        # Add full (and cased) model names here
        ("poolformer", "PoolFormer"),
        ("convnext", "ConvNext"),
        ("yoso", "YOSO"),
        ("swin", "Swin"),
        ("vilt", "ViLT"),
        ("vit_mae", "ViTMAE"),
        ("realm", "Realm"),
        ("nystromformer", "Nystromformer"),
        ("xglm", "XGLM"),
        ("imagegpt", "ImageGPT"),
        ("qdqbert", "QDQBert"),
        ("vision-encoder-decoder", "Vision Encoder decoder"),
        ("trocr", "TrOCR"),
        ("fnet", "FNet"),
        ("segformer", "SegFormer"),
        ("vision-text-dual-encoder", "VisionTextDualEncoder"),
        ("perceiver", "Perceiver"),
        ("gptj", "GPT-J"),
        ("beit", "BEiT"),
        ("plbart", "PLBart"),
        ("rembert", "RemBERT"),
        ("layoutlmv2", "LayoutLMv2"),
        ("visual_bert", "VisualBert"),
        ("canine", "Canine"),
        ("roformer", "RoFormer"),
        ("clip", "CLIP"),
        ("bigbird_pegasus", "BigBirdPegasus"),
        ("deit", "DeiT"),
        ("luke", "LUKE"),
        ("detr", "DETR"),
        ("gpt_neo", "GPT Neo"),
        ("big_bird", "BigBird"),
        ("speech_to_text_2", "Speech2Text2"),
        ("speech_to_text", "Speech2Text"),
        ("vit", "ViT"),
        ("wav2vec2", "Wav2Vec2"),
        ("m2m_100", "M2M100"),
        ("convbert", "ConvBERT"),
        ("led", "LED"),
        ("blenderbot-small", "BlenderbotSmall"),
        ("retribert", "RetriBERT"),
        ("ibert", "I-BERT"),
        ("t5", "T5"),
        ("mobilebert", "MobileBERT"),
        ("distilbert", "DistilBERT"),
        ("albert", "ALBERT"),
        ("bert-generation", "Bert Generation"),
        ("camembert", "CamemBERT"),
        ("xlm-roberta", "XLM-RoBERTa"),
        ("xlm-roberta-xl", "XLM-RoBERTa-XL"),
        ("pegasus", "Pegasus"),
        ("blenderbot", "Blenderbot"),
        ("marian", "Marian"),
        ("mbart", "mBART"),
        ("megatron-bert", "MegatronBert"),
        ("bart", "BART"),
        ("reformer", "Reformer"),
        ("longformer", "Longformer"),
        ("roberta", "RoBERTa"),
        ("flaubert", "FlauBERT"),
        ("fsmt", "FairSeq Machine-Translation"),
        ("squeezebert", "SqueezeBERT"),
        ("bert", "BERT"),
        ("openai-gpt", "OpenAI GPT"),
        ("gpt2", "OpenAI GPT-2"),
        ("transfo-xl", "Transformer-XL"),
        ("xlnet", "XLNet"),
        ("xlm", "XLM"),
        ("ctrl", "CTRL"),
        ("electra", "ELECTRA"),
        ("encoder-decoder", "Encoder decoder"),
        ("speech-encoder-decoder", "Speech Encoder decoder"),
        ("funnel", "Funnel Transformer"),
        ("lxmert", "LXMERT"),
        ("deberta-v2", "DeBERTa-v2"),
        ("deberta", "DeBERTa"),
        ("layoutlm", "LayoutLM"),
        ("dpr", "DPR"),
        ("rag", "RAG"),
        ("xlm-prophetnet", "XLMProphetNet"),
        ("prophetnet", "ProphetNet"),
        ("mt5", "mT5"),
        ("mpnet", "MPNet"),
        ("tapas", "TAPAS"),
        ("hubert", "Hubert"),
        ("barthez", "BARThez"),
        ("phobert", "PhoBERT"),
        ("bartpho", "BARTpho"),
        ("cpm", "CPM"),
        ("bertweet", "Bertweet"),
        ("bert-japanese", "BertJapanese"),
        ("byt5", "ByT5"),
        ("mbart50", "mBART-50"),
        ("splinter", "Splinter"),
        ("sew-d", "SEW-D"),
        ("sew", "SEW"),
        ("unispeech-sat", "UniSpeechSat"),
        ("unispeech", "UniSpeech"),
        ("wavlm", "WavLM"),
        ("bort", "BORT"),
        ("dialogpt", "DialoGPT"),
        ("xls_r", "XLS-R"),
        ("t5v1.1", "T5v1.1"),
        ("herbert", "HerBERT"),
        ("wav2vec2_phoneme", "Wav2Vec2Phoneme"),
        ("megatron_gpt2", "MegatronGPT2"),
        ("xlsr_wav2vec2", "XLSR-Wav2Vec2"),
        ("mluke", "mLUKE"),
        ("layoutxlm", "LayoutXLM"),
    ]
)


# Model types whose `transformers.models` submodule name is NOT simply the model type with "-" replaced by "_".
SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict((("openai-gpt", "openai"),))


def model_type_to_module_name(key):
    """
    Return the name of the `transformers.models` submodule that implements model type `key`.

    By default dashes become underscores (e.g. "xlm-roberta" -> "xlm_roberta"); model types listed in
    `SPECIAL_MODEL_TYPE_TO_MODULE_NAME` use the module name recorded there instead.
    """
    overrides = SPECIAL_MODEL_TYPE_TO_MODULE_NAME
    if key not in overrides:
        return key.replace("-", "_")
    return overrides[key]


def config_class_to_model_type(config):
    """
    Return the model type whose configuration class name equals `config`, or `None` if no entry matches.

    `config` is compared against the class *names* stored in `CONFIG_MAPPING_NAMES`; the first match in mapping order
    is returned.
    """
    matches = (model_type for model_type, class_name in CONFIG_MAPPING_NAMES.items() if class_name == config)
    return next(matches, None)


class _LazyConfigMapping(OrderedDict):
    """
    A dictionary that lazily load its values when they are requested.
    """

    def __init__(self, mapping):
        self._mapping = mapping
354
        self._extra_content = {}
355
356
357
        self._modules = {}

    def __getitem__(self, key):
358
359
        if key in self._extra_content:
            return self._extra_content[key]
360
361
362
363
364
365
366
367
368
        if key not in self._mapping:
            raise KeyError(key)
        value = self._mapping[key]
        module_name = model_type_to_module_name(key)
        if module_name not in self._modules:
            self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
        return getattr(self._modules[module_name], value)

    def keys(self):
369
        return list(self._mapping.keys()) + list(self._extra_content.keys())
370
371

    def values(self):
372
        return [self[k] for k in self._mapping.keys()] + list(self._extra_content.values())
373

374
    def items(self):
375
        return [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items())
376
377

    def __iter__(self):
378
        return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
379
380

    def __contains__(self, item):
381
382
383
384
385
386
387
388
389
        return item in self._mapping or item in self._extra_content

    def register(self, key, value):
        """
        Register a new configuration in this mapping.
        """
        if key in self._mapping.keys():
            raise ValueError(f"'{key}' is already used by a Transformers config, pick another name.")
        self._extra_content[key] = value
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454


# Model type -> configuration class. Classes are imported on first access rather than when this module is imported.
CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)


class _LazyLoadAllMappings(OrderedDict):
    """
    A mapping that will load all pairs of key values at the first access (either by indexing, requestions keys, values,
    etc.)

    Args:
        mapping: The mapping to load.
    """

    def __init__(self, mapping):
        self._mapping = mapping
        self._initialized = False
        self._data = {}

    def _initialize(self):
        if self._initialized:
            return
        warnings.warn(
            "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. "
            "It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.",
            FutureWarning,
        )

        for model_type, map_name in self._mapping.items():
            module_name = model_type_to_module_name(model_type)
            module = importlib.import_module(f".{module_name}", "transformers.models")
            mapping = getattr(module, map_name)
            self._data.update(mapping)

        self._initialized = True

    def __getitem__(self, key):
        self._initialize()
        return self._data[key]

    def keys(self):
        self._initialize()
        return self._data.keys()

    def values(self):
        self._initialize()
        return self._data.values()

    def items(self):
        self._initialize()
        return self._data.keys()

    def __iter__(self):
        self._initialize()
        return iter(self._data)

    def __contains__(self, item):
        self._initialize()
        return item in self._data


# Deprecated merged archive map. The per-model maps are imported and merged on first access, which also emits a
# FutureWarning.
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = _LazyLoadAllMappings(CONFIG_ARCHIVE_MAP_MAPPING_NAMES)


def _get_class_name(model_class: Union[str, List[str]]):
455
    if isinstance(model_class, (list, tuple)):
Stas Bekman's avatar
Stas Bekman committed
456
457
        return " or ".join([f"[`{c}`]" for c in model_class if c is not None])
    return f"[`{model_class}`]"
458
459


460
461
462
463
464
def _list_model_options(indent, config_to_class=None, use_model_types=True):
    if config_to_class is None and not use_model_types:
        raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
    if use_model_types:
        if config_to_class is None:
Stas Bekman's avatar
Stas Bekman committed
465
            model_type_to_name = {model_type: f"[`{config}`]" for model_type, config in CONFIG_MAPPING_NAMES.items()}
466
467
        else:
            model_type_to_name = {
468
469
470
                model_type: _get_class_name(model_class)
                for model_type, model_class in config_to_class.items()
                if model_type in MODEL_NAMES_MAPPING
471
472
            }
        lines = [
473
            f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
474
            for model_type in sorted(model_type_to_name.keys())
475
476
        ]
    else:
477
478
479
480
481
        config_to_name = {
            CONFIG_MAPPING_NAMES[config]: _get_class_name(clas)
            for config, clas in config_to_class.items()
            if config in CONFIG_MAPPING_NAMES
        }
482
        config_to_model_name = {
483
            config: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING_NAMES.items()
484
485
        }
        lines = [
Stas Bekman's avatar
Stas Bekman committed
486
            f"{indent}- [`{config_name}`] configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
487
            for config_name in sorted(config_to_name.keys())
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
        ]
    return "\n".join(lines)


def replace_list_option_in_docstrings(config_to_class=None, use_model_types=True):
    """
    Build a decorator that replaces the first "List options" placeholder line in a function's docstring with the list
    produced by `_list_model_options`.

    Args:
        config_to_class (`dict`, *optional*): Forwarded to `_list_model_options`.
        use_model_types (`bool`, *optional*, defaults to `True`):
            Forwarded to `_list_model_options`. When true, the generated list is indented four extra spaces relative to
            the placeholder.

    Returns:
        A decorator that rewrites `fn.__doc__` in place and returns `fn`. If `fn` has no docstring (for instance when
        Python runs with `-OO`, which strips docstrings), `fn` is returned unchanged.

    Raises:
        ValueError: (from the decorator) if the docstring exists but has no "List options" placeholder line.
    """
    placeholder = re.compile(r"^(\s*)List options\s*$")

    def docstring_decorator(fn):
        docstrings = fn.__doc__
        if docstrings is None:
            # Nothing to rewrite; previously this crashed with AttributeError at import time under `python -OO`.
            return fn

        lines = docstrings.split("\n")
        for i, line in enumerate(lines):
            match = placeholder.search(line)
            if match is not None:
                break
        else:
            raise ValueError(
                f"The function {fn} should have an empty 'List options' in its docstring as placeholder, current docstring is:\n{docstrings}"
            )

        indent = match.groups()[0]
        if use_model_types:
            indent = f"{indent}    "
        lines[i] = _list_model_options(indent, config_to_class=config_to_class, use_model_types=use_model_types)
        fn.__doc__ = "\n".join(lines)
        return fn

    return docstring_decorator


class AutoConfig:
Lysandre Debut's avatar
Lysandre Debut committed
516
    r"""
517
    This is a generic configuration class that will be instantiated as one of the configuration classes of the library
518
    when created with the [`~AutoConfig.from_pretrained`] class method.
519

520
    This class cannot be instantiated directly using `__init__()` (throws an error).
521
    """
522

523
    def __init__(self):
524
525
526
527
        raise EnvironmentError(
            "AutoConfig is designed to be instantiated "
            "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
        )
528

529
    @classmethod
530
531
532
533
    def for_model(cls, model_type: str, *args, **kwargs):
        if model_type in CONFIG_MAPPING:
            config_class = CONFIG_MAPPING[model_type]
            return config_class(*args, **kwargs)
534
        raise ValueError(
535
            f"Unrecognized model identifier: {model_type}. Should contain one of {', '.join(CONFIG_MAPPING.keys())}"
536
        )
537

538
    @classmethod
539
    @replace_list_option_in_docstrings()
540
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
541
542
        r"""
        Instantiate one of the configuration classes of the library from a pretrained model configuration.
543

Sylvain Gugger's avatar
Sylvain Gugger committed
544
545
        The configuration class to instantiate is selected based on the `model_type` property of the config object that
        is loaded, or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
546

547
        List options
Lysandre Debut's avatar
Lysandre Debut committed
548
549

        Args:
550
            pretrained_model_name_or_path (`str` or `os.PathLike`):
551
552
                Can be either:

553
554
555
556
                    - A string, the *model id* of a pretrained model configuration hosted inside a model repo on
                      huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
                      namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
                    - A path to a *directory* containing a configuration file saved using the
Sylvain Gugger's avatar
Sylvain Gugger committed
557
558
                      [`~PretrainedConfig.save_pretrained`] method, or the [`~PreTrainedModel.save_pretrained`] method,
                      e.g., `./my_model_directory/`.
559
560
561
                    - A path or url to a saved configuration JSON *file*, e.g.,
                      `./my_model_directory/configuration.json`.
            cache_dir (`str` or `os.PathLike`, *optional*):
562
563
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
564
            force_download (`bool`, *optional*, defaults to `False`):
565
566
                Whether or not to force the (re-)download the model weights and configuration files and override the
                cached versions if they exist.
567
            resume_download (`bool`, *optional*, defaults to `False`):
568
569
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
570
            proxies (`Dict[str, str]`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
571
572
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
573
            revision(`str`, *optional*, defaults to `"main"`):
Julien Chaumond's avatar
Julien Chaumond committed
574
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
575
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
Julien Chaumond's avatar
Julien Chaumond committed
576
                identifier allowed by git.
577
578
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                If `False`, then this function returns just the final configuration object.
579

Sylvain Gugger's avatar
Sylvain Gugger committed
580
581
582
                If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
                dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
                part of `kwargs` which has not been used to update `config` and is otherwise ignored.
583
            trust_remote_code (`bool`, *optional*, defaults to `False`):
584
                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
Sylvain Gugger's avatar
Sylvain Gugger committed
585
586
                should only be set to `True` for repositories you trust and in which you have read the code, as it will
                execute code present on the Hub on your local machine.
587
            kwargs(additional keyword arguments, *optional*):
588
                The values in kwargs of any keys which are configuration attributes will be used to override the loaded
Sylvain Gugger's avatar
Sylvain Gugger committed
589
                values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
590
                by the `return_unused_kwargs` keyword parameter.
Lysandre Debut's avatar
Lysandre Debut committed
591

592
        Examples:
593

594
595
        ```python
        >>> from transformers import AutoConfig
596

597
        >>> # Download configuration from huggingface.co and cache.
Sylvain Gugger's avatar
Sylvain Gugger committed
598
        >>> config = AutoConfig.from_pretrained("bert-base-uncased")
Lysandre Debut's avatar
Lysandre Debut committed
599

600
        >>> # Download configuration from huggingface.co (user-uploaded) and cache.
Sylvain Gugger's avatar
Sylvain Gugger committed
601
        >>> config = AutoConfig.from_pretrained("dbmdz/bert-base-german-cased")
Lysandre Debut's avatar
Lysandre Debut committed
602

603
        >>> # If configuration file is in a directory (e.g., was saved using *save_pretrained('./test/saved_model/')*).
Sylvain Gugger's avatar
Sylvain Gugger committed
604
        >>> config = AutoConfig.from_pretrained("./test/bert_saved_model/")
605

606
        >>> # Load a specific configuration file.
Sylvain Gugger's avatar
Sylvain Gugger committed
607
        >>> config = AutoConfig.from_pretrained("./test/bert_saved_model/my_configuration.json")
608

609
        >>> # Change some config attributes when loading a pretrained config.
Sylvain Gugger's avatar
Sylvain Gugger committed
610
        >>> config = AutoConfig.from_pretrained("bert-base-uncased", output_attentions=True, foo=False)
611
612
        >>> config.output_attentions
        True
Sylvain Gugger's avatar
Sylvain Gugger committed
613
614
615
616

        >>> config, unused_kwargs = AutoConfig.from_pretrained(
        ...     "bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
        ... )
617
618
        >>> config.output_attentions
        True
Sylvain Gugger's avatar
Sylvain Gugger committed
619

620
621
622
        >>> config.unused_kwargs
        {'foo': False}
        ```"""
623
        kwargs["_from_auto"] = True
624
625
        kwargs["name_or_path"] = pretrained_model_name_or_path
        trust_remote_code = kwargs.pop("trust_remote_code", False)
626
        config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
627
628
629
630
631
632
633
634
        if "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]:
            if not trust_remote_code:
                raise ValueError(
                    f"Loading {pretrained_model_name_or_path} requires you to execute the configuration file in that repo "
                    "on your local machine. Make sure you have read the code there to avoid malicious use, then set "
                    "the option `trust_remote_code=True` to remove this error."
                )
            if kwargs.get("revision", None) is None:
635
                logger.warning(
636
637
638
639
640
641
642
643
644
645
                    "Explicitly passing a `revision` is encouraged when loading a configuration with custom code to "
                    "ensure no malicious code has been contributed in a newer revision."
                )
            class_ref = config_dict["auto_map"]["AutoConfig"]
            module_file, class_name = class_ref.split(".")
            config_class = get_class_from_dynamic_module(
                pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
            )
            return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "model_type" in config_dict:
646
647
648
649
650
            config_class = CONFIG_MAPPING[config_dict["model_type"]]
            return config_class.from_dict(config_dict, **kwargs)
        else:
            # Fallback: use pattern matching on the string.
            for pattern, config_class in CONFIG_MAPPING.items():
651
                if pattern in str(pretrained_model_name_or_path):
652
653
                    return config_class.from_dict(config_dict, **kwargs)

654
        raise ValueError(
655
            f"Unrecognized model in {pretrained_model_name_or_path}. "
656
            f"Should have a `model_type` key in its {CONFIG_NAME}, or contain one of the following strings "
657
            f"in its name: {', '.join(CONFIG_MAPPING.keys())}"
658
        )
659
660
661
662
663
664
665

    @staticmethod
    def register(model_type, config):
        """
        Register a new configuration for this class.

        Args:
666
667
            model_type (`str`): The model type like "bert" or "gpt".
            config ([`PretrainedConfig`]): The config to register.
668
669
670
671
672
673
674
675
        """
        if issubclass(config, PretrainedConfig) and config.model_type != model_type:
            raise ValueError(
                "The config you are passing has a `model_type` attribute that is not consistent with the model type "
                f"you passed (config has {config.model_type} and you passed {model_type}. Fix one of those so they "
                "match!"
            )
        CONFIG_MAPPING.register(model_type, config)