# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Auto Config class."""
import importlib
import re
import warnings
from collections import OrderedDict
from typing import List, Union

from ...configuration_utils import PretrainedConfig
from ...file_utils import CONFIG_NAME
from ...utils import logging
from .dynamic import get_class_from_dynamic_module


logger = logging.get_logger(__name__)


CONFIG_MAPPING_NAMES = OrderedDict(
    [
        # Add configs here
        ("imagegpt", "ImageGPTConfig"),
        ("qdqbert", "QDQBertConfig"),
        ("vision-encoder-decoder", "VisionEncoderDecoderConfig"),
        ("trocr", "TrOCRConfig"),
        ("fnet", "FNetConfig"),
        ("segformer", "SegformerConfig"),
        ("gptj", "GPTJConfig"),
        ("layoutlmv2", "LayoutLMv2Config"),
        ("beit", "BeitConfig"),
        ("rembert", "RemBertConfig"),
        ("visual_bert", "VisualBertConfig"),
        ("canine", "CanineConfig"),
        ("roformer", "RoFormerConfig"),
        ("clip", "CLIPConfig"),
        ("bigbird_pegasus", "BigBirdPegasusConfig"),
        ("deit", "DeiTConfig"),
        ("luke", "LukeConfig"),
        ("detr", "DetrConfig"),
        ("gpt_neo", "GPTNeoConfig"),
        ("big_bird", "BigBirdConfig"),
        ("speech_to_text_2", "Speech2Text2Config"),
        ("speech_to_text", "Speech2TextConfig"),
        ("vit", "ViTConfig"),
        ("wav2vec2", "Wav2Vec2Config"),
        ("m2m_100", "M2M100Config"),
        ("convbert", "ConvBertConfig"),
        ("led", "LEDConfig"),
        ("blenderbot-small", "BlenderbotSmallConfig"),
        ("retribert", "RetriBertConfig"),
        ("ibert", "IBertConfig"),
        ("mt5", "MT5Config"),
        ("t5", "T5Config"),
        ("mobilebert", "MobileBertConfig"),
        ("distilbert", "DistilBertConfig"),
        ("albert", "AlbertConfig"),
        ("bert-generation", "BertGenerationConfig"),
        ("camembert", "CamembertConfig"),
        ("xlm-roberta", "XLMRobertaConfig"),
        ("pegasus", "PegasusConfig"),
        ("marian", "MarianConfig"),
        ("mbart", "MBartConfig"),
        ("megatron-bert", "MegatronBertConfig"),
        ("mpnet", "MPNetConfig"),
        ("bart", "BartConfig"),
        ("blenderbot", "BlenderbotConfig"),
        ("reformer", "ReformerConfig"),
        ("longformer", "LongformerConfig"),
        ("roberta", "RobertaConfig"),
        ("deberta-v2", "DebertaV2Config"),
        ("deberta", "DebertaConfig"),
        ("flaubert", "FlaubertConfig"),
        ("fsmt", "FSMTConfig"),
        ("squeezebert", "SqueezeBertConfig"),
        ("hubert", "HubertConfig"),
        ("bert", "BertConfig"),
        ("openai-gpt", "OpenAIGPTConfig"),
        ("gpt2", "GPT2Config"),
        ("transfo-xl", "TransfoXLConfig"),
        ("xlnet", "XLNetConfig"),
        ("xlm-prophetnet", "XLMProphetNetConfig"),
        ("prophetnet", "ProphetNetConfig"),
        ("xlm", "XLMConfig"),
        ("ctrl", "CTRLConfig"),
        ("electra", "ElectraConfig"),
        ("speech-encoder-decoder", "SpeechEncoderDecoderConfig"),
        ("encoder-decoder", "EncoderDecoderConfig"),
        ("funnel", "FunnelConfig"),
        ("lxmert", "LxmertConfig"),
        ("dpr", "DPRConfig"),
        ("layoutlm", "LayoutLMConfig"),
        ("rag", "RagConfig"),
        ("tapas", "TapasConfig"),
        ("splinter", "SplinterConfig"),
        ("sew-d", "SEWDConfig"),
        ("sew", "SEWConfig"),
        ("unispeech-sat", "UniSpeechSatConfig"),
        ("unispeech", "UniSpeechConfig"),
    ]
)

CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
    [
        # Add archive maps here
        ("imagegpt", "IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("qdqbert", "QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("fnet", "FNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("pegasus", "PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("segformer", "SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("canine", "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("clip", "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("bigbird_pegasus", "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("detr", "DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("gpt_neo", "GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("big_bird", "BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("megatron-bert", "MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("speech_to_text", "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("speech_to_text_2", "SPEECH_TO_TEXT_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("vit", "VIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("wav2vec2", "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("convbert", "CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("blenderbot-small", "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("bert", "BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("bart", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("blenderbot", "BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("mbart", "MBART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("openai-gpt", "OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("ctrl", "CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("xlnet", "XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("xlm", "XLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("roberta", "ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("distilbert", "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("albert", "ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("camembert", "CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("t5", "T5_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("xlm-roberta", "XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("flaubert", "FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("fsmt", "FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("electra", "ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("retribert", "RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("funnel", "FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("dpr", "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("deberta", "DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("deberta-v2", "DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("squeezebert", "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("xlm-prophetnet", "XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("prophetnet", "PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("mpnet", "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("hubert", "HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("splinter", "SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("sew-d", "SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("sew", "SEW_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("unispeech-sat", "UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("unispeech", "UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP"),
    ]
)

MODEL_NAMES_MAPPING = OrderedDict(
    [
        # Add full (and cased) model names here
        ("imagegpt", "ImageGPT"),
        ("qdqbert", "QDQBert"),
        ("vision-encoder-decoder", "Vision Encoder decoder"),
        ("trocr", "TrOCR"),
        ("fnet", "FNet"),
        ("segformer", "SegFormer"),
        ("gptj", "GPT-J"),
        ("beit", "BEiT"),
        ("rembert", "RemBERT"),
        ("layoutlmv2", "LayoutLMv2"),
        ("visual_bert", "VisualBert"),
        ("canine", "Canine"),
        ("roformer", "RoFormer"),
        ("clip", "CLIP"),
        ("bigbird_pegasus", "BigBirdPegasus"),
        ("deit", "DeiT"),
        ("luke", "LUKE"),
        ("detr", "DETR"),
        ("gpt_neo", "GPT Neo"),
        ("big_bird", "BigBird"),
        ("speech_to_text_2", "Speech2Text2"),
        ("speech_to_text", "Speech2Text"),
        ("vit", "ViT"),
        ("wav2vec2", "Wav2Vec2"),
        ("m2m_100", "M2M100"),
        ("convbert", "ConvBERT"),
        ("led", "LED"),
        ("blenderbot-small", "BlenderbotSmall"),
        ("retribert", "RetriBERT"),
        ("ibert", "I-BERT"),
        ("t5", "T5"),
        ("mobilebert", "MobileBERT"),
        ("distilbert", "DistilBERT"),
        ("albert", "ALBERT"),
        ("bert-generation", "Bert Generation"),
        ("camembert", "CamemBERT"),
        ("xlm-roberta", "XLM-RoBERTa"),
        ("pegasus", "Pegasus"),
        ("blenderbot", "Blenderbot"),
        ("marian", "Marian"),
        ("mbart", "mBART"),
        ("megatron-bert", "MegatronBert"),
        ("bart", "BART"),
        ("reformer", "Reformer"),
        ("longformer", "Longformer"),
        ("roberta", "RoBERTa"),
        ("flaubert", "FlauBERT"),
        ("fsmt", "FairSeq Machine-Translation"),
        ("squeezebert", "SqueezeBERT"),
        ("bert", "BERT"),
        ("openai-gpt", "OpenAI GPT"),
        ("gpt2", "OpenAI GPT-2"),
        ("transfo-xl", "Transformer-XL"),
        ("xlnet", "XLNet"),
        ("xlm", "XLM"),
        ("ctrl", "CTRL"),
        ("electra", "ELECTRA"),
        ("encoder-decoder", "Encoder decoder"),
        ("speech-encoder-decoder", "Speech Encoder decoder"),
        ("funnel", "Funnel Transformer"),
        ("lxmert", "LXMERT"),
        ("deberta-v2", "DeBERTa-v2"),
        ("deberta", "DeBERTa"),
        ("layoutlm", "LayoutLM"),
        ("dpr", "DPR"),
        ("rag", "RAG"),
        ("xlm-prophetnet", "XLMProphetNet"),
        ("prophetnet", "ProphetNet"),
        ("mt5", "mT5"),
        ("mpnet", "MPNet"),
        ("tapas", "TAPAS"),
        ("hubert", "Hubert"),
        ("barthez", "BARThez"),
        ("phobert", "PhoBERT"),
        ("bartpho", "BARTpho"),
        ("cpm", "CPM"),
        ("bertweet", "Bertweet"),
        ("bert-japanese", "BertJapanese"),
        ("byt5", "ByT5"),
        ("mbart50", "mBART-50"),
        ("splinter", "Splinter"),
        ("sew-d", "SEW-D"),
        ("sew", "SEW"),
        ("unispeech-sat", "UniSpeechSat"),
        ("unispeech", "UniSpeech"),
    ]
)

SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict([("openai-gpt", "openai")])


def model_type_to_module_name(key):
    """Converts a config key to the corresponding module."""
    # Special treatment
    if key in SPECIAL_MODEL_TYPE_TO_MODULE_NAME:
        return SPECIAL_MODEL_TYPE_TO_MODULE_NAME[key]

    return key.replace("-", "_")
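
# Illustrative (comments only, nothing executed here): module resolution mostly
# replaces dashes with underscores, with special cases taken from the table above:
#
#   model_type_to_module_name("xlm-roberta")  # -> "xlm_roberta"
#   model_type_to_module_name("openai-gpt")   # -> "openai" (special-cased)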


def config_class_to_model_type(config):
    """Converts a config class name to the corresponding model type"""
    for key, cls in CONFIG_MAPPING_NAMES.items():
        if cls == config:
            return key
    return None
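
# For instance (illustrative): config_class_to_model_type("BertConfig") returns
# "bert", and an unknown class name falls through to None.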


class _LazyConfigMapping(OrderedDict):
    """
    A dictionary that lazily loads its values when they are requested.
    """

    def __init__(self, mapping):
        self._mapping = mapping
        self._extra_content = {}
        self._modules = {}

    def __getitem__(self, key):
        if key in self._extra_content:
            return self._extra_content[key]
        if key not in self._mapping:
            raise KeyError(key)
        value = self._mapping[key]
        module_name = model_type_to_module_name(key)
        if module_name not in self._modules:
            self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
        return getattr(self._modules[module_name], value)

    def keys(self):
        return list(self._mapping.keys()) + list(self._extra_content.keys())

    def values(self):
        return [self[k] for k in self._mapping.keys()] + list(self._extra_content.values())

    def items(self):
        return [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items())

    def __iter__(self):
        return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))

    def __contains__(self, item):
        return item in self._mapping or item in self._extra_content

    def register(self, key, value):
        """
        Register a new configuration in this mapping.
        """
        if key in self._mapping.keys():
            raise ValueError(f"'{key}' is already used by a Transformers config, pick another name.")
        self._extra_content[key] = value


CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
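
# Note (illustrative): lookups in CONFIG_MAPPING are lazy -- the first access to
# a key imports the corresponding model module, e.g.:
#
#   config_class = CONFIG_MAPPING["bert"]  # imports transformers.models.bert, returns BertConfig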


class _LazyLoadAllMappings(OrderedDict):
    """
    A mapping that will load all pairs of key values at the first access (either by indexing, requesting keys, values,
    etc.)

    Args:
        mapping: The mapping to load.
    """

    def __init__(self, mapping):
        self._mapping = mapping
        self._initialized = False
        self._data = {}

    def _initialize(self):
        if self._initialized:
            return
        warnings.warn(
            "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. "
            "It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.",
            FutureWarning,
        )

        for model_type, map_name in self._mapping.items():
            module_name = model_type_to_module_name(model_type)
            module = importlib.import_module(f".{module_name}", "transformers.models")
            mapping = getattr(module, map_name)
            self._data.update(mapping)

        self._initialized = True

    def __getitem__(self, key):
        self._initialize()
        return self._data[key]

    def keys(self):
        self._initialize()
        return self._data.keys()

    def values(self):
        self._initialize()
        return self._data.values()

    def items(self):
        self._initialize()
        return self._data.items()

    def __iter__(self):
        self._initialize()
        return iter(self._data)

    def __contains__(self, item):
        self._initialize()
        return item in self._data


ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = _LazyLoadAllMappings(CONFIG_ARCHIVE_MAP_MAPPING_NAMES)
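
# Illustrative: any access (indexing, iteration, membership tests) triggers the
# one-time initialization above, importing every model module and emitting the
# FutureWarning, e.g.:
#
#   "bert-base-uncased" in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP  # loads all archive maps first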


def _get_class_name(model_class: Union[str, List[str]]):
    if isinstance(model_class, (list, tuple)):
        return " or ".join([f":class:`~transformers.{c}`" for c in model_class if c is not None])
    return f":class:`~transformers.{model_class}`"


def _list_model_options(indent, config_to_class=None, use_model_types=True):
    if config_to_class is None and not use_model_types:
        raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
    if use_model_types:
        if config_to_class is None:
            model_type_to_name = {
                model_type: f":class:`~transformers.{config}`" for model_type, config in CONFIG_MAPPING_NAMES.items()
            }
        else:
            model_type_to_name = {
                model_type: _get_class_name(model_class)
                for model_type, model_class in config_to_class.items()
                if model_type in MODEL_NAMES_MAPPING
            }
        lines = [
            f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
            for model_type in sorted(model_type_to_name.keys())
        ]
    else:
        config_to_name = {
            CONFIG_MAPPING_NAMES[config]: _get_class_name(clas)
            for config, clas in config_to_class.items()
            if config in CONFIG_MAPPING_NAMES
        }
        config_to_model_name = {
            config: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING_NAMES.items()
        }
        lines = [
            f"{indent}- :class:`~transformers.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
            for config_name in sorted(config_to_name.keys())
        ]
    return "\n".join(lines)


def replace_list_option_in_docstrings(config_to_class=None, use_model_types=True):
    def docstring_decorator(fn):
        docstrings = fn.__doc__
        lines = docstrings.split("\n")
        i = 0
        while i < len(lines) and re.search(r"^(\s*)List options\s*$", lines[i]) is None:
            i += 1
        if i < len(lines):
            indent = re.search(r"^(\s*)List options\s*$", lines[i]).groups()[0]
            if use_model_types:
                indent = f"{indent}    "
            lines[i] = _list_model_options(indent, config_to_class=config_to_class, use_model_types=use_model_types)
            docstrings = "\n".join(lines)
        else:
            raise ValueError(
                f"The function {fn} should have an empty 'List options' in its docstring as placeholder, current docstring is:\n{docstrings}"
            )
        fn.__doc__ = docstrings
        return fn

    return docstring_decorator
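
# Usage sketch: a decorated function must have a docstring containing a line that
# consists only of "List options" (as `AutoConfig.from_pretrained` below does);
# that placeholder line is replaced by a generated bullet list at import time:
#
#   @replace_list_option_in_docstrings()
#   def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
#       """...
#       List options
#       ..."""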


class AutoConfig:
    r"""
    This is a generic configuration class that will be instantiated as one of the configuration classes of the library
    when created with the :meth:`~transformers.AutoConfig.from_pretrained` class method.

    This class cannot be instantiated directly using ``__init__()`` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoConfig is designed to be instantiated "
            "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    def for_model(cls, model_type: str, *args, **kwargs):
        if model_type in CONFIG_MAPPING:
            config_class = CONFIG_MAPPING[model_type]
            return config_class(*args, **kwargs)
        raise ValueError(
            f"Unrecognized model identifier: {model_type}. Should contain one of {', '.join(CONFIG_MAPPING.keys())}"
        )

    @classmethod
    @replace_list_option_in_docstrings()
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r"""
        Instantiate one of the configuration classes of the library from a pretrained model configuration.

        The configuration class to instantiate is selected based on the :obj:`model_type` property of the config object
        that is loaded, or when it's missing, by falling back to using pattern matching on
        :obj:`pretrained_model_name_or_path`:

        List options

        Args:
            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                Can be either:

                    - A string, the `model id` of a pretrained model configuration hosted inside a model repo on
                      huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
                      namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
                    - A path to a `directory` containing a configuration file saved using the
                      :meth:`~transformers.PretrainedConfig.save_pretrained` method, or the
                      :meth:`~transformers.PreTrainedModel.save_pretrained` method, e.g., ``./my_model_directory/``.
                    - A path or url to a saved configuration JSON `file`, e.g.,
                      ``./my_model_directory/configuration.json``.
            cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
                identifier allowed by git.
            return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
                If :obj:`False`, then this function returns just the final configuration object.

                If :obj:`True`, then this function returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs`
                is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e.,
                the part of ``kwargs`` which has not been used to update ``config`` and is otherwise ignored.
            trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
                should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
                will execute code present on the Hub on your local machine.
            kwargs (additional keyword arguments, `optional`):
                The values in kwargs of any keys which are configuration attributes will be used to override the loaded
                values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
                by the ``return_unused_kwargs`` keyword parameter.

        Examples::

            >>> from transformers import AutoConfig

            >>> # Download configuration from huggingface.co and cache.
            >>> config = AutoConfig.from_pretrained('bert-base-uncased')

            >>> # Download configuration from huggingface.co (user-uploaded) and cache.
            >>> config = AutoConfig.from_pretrained('dbmdz/bert-base-german-cased')

            >>> # If configuration file is in a directory (e.g., was saved using `save_pretrained('./test/saved_model/')`).
            >>> config = AutoConfig.from_pretrained('./test/bert_saved_model/')

            >>> # Load a specific configuration file.
            >>> config = AutoConfig.from_pretrained('./test/bert_saved_model/my_configuration.json')

            >>> # Change some config attributes when loading a pretrained config.
            >>> config = AutoConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False)
            >>> config.output_attentions
            True
            >>> config, unused_kwargs = AutoConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False, return_unused_kwargs=True)
            >>> config.output_attentions
            True
            >>> unused_kwargs
            {'foo': False}
        """
        kwargs["_from_auto"] = True
        kwargs["name_or_path"] = pretrained_model_name_or_path
        trust_remote_code = kwargs.pop("trust_remote_code", False)
        config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
        if "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]:
            if not trust_remote_code:
                raise ValueError(
                    f"Loading {pretrained_model_name_or_path} requires you to execute the configuration file in that repo "
                    "on your local machine. Make sure you have read the code there to avoid malicious use, then set "
                    "the option `trust_remote_code=True` to remove this error."
                )
            if kwargs.get("revision", None) is None:
                logger.warning(
                    "Explicitly passing a `revision` is encouraged when loading a configuration with custom code to "
                    "ensure no malicious code has been contributed in a newer revision."
                )
            class_ref = config_dict["auto_map"]["AutoConfig"]
            module_file, class_name = class_ref.split(".")
            config_class = get_class_from_dynamic_module(
                pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
            )
            return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "model_type" in config_dict:
            config_class = CONFIG_MAPPING[config_dict["model_type"]]
            return config_class.from_dict(config_dict, **kwargs)
        else:
            # Fallback: use pattern matching on the string.
            for pattern, config_class in CONFIG_MAPPING.items():
                if pattern in str(pretrained_model_name_or_path):
                    return config_class.from_dict(config_dict, **kwargs)

        raise ValueError(
            f"Unrecognized model in {pretrained_model_name_or_path}. "
            f"Should have a `model_type` key in its {CONFIG_NAME}, or contain one of the following strings "
            f"in its name: {', '.join(CONFIG_MAPPING.keys())}"
        )

    @staticmethod
    def register(model_type, config):
        """
        Register a new configuration for this class.

        Args:
            model_type (:obj:`str`): The model type like "bert" or "gpt".
            config (:class:`~transformers.PretrainedConfig`): The config to register.
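
        Examples::

            >>> # Minimal sketch; ``NewModelConfig`` is a hypothetical custom config class.
            >>> from transformers import AutoConfig, PretrainedConfig

            >>> class NewModelConfig(PretrainedConfig):
            ...     model_type = "new-model"

            >>> AutoConfig.register("new-model", NewModelConfig)
            >>> config = AutoConfig.for_model("new-model")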
        """
        if issubclass(config, PretrainedConfig) and config.model_type != model_type:
            raise ValueError(
                "The config you are passing has a `model_type` attribute that is not consistent with the model type "
                f"you passed (config has {config.model_type} and you passed {model_type}. Fix one of those so they "
                "match!"
            )
        CONFIG_MAPPING.register(model_type, config)