configuration_auto.py 21.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
""" Auto Config class. """
16

17
import re
18
from collections import OrderedDict
19

Sylvain Gugger's avatar
Sylvain Gugger committed
20
21
22
23
24
from ...configuration_utils import PretrainedConfig
from ..albert.configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from ..bart.configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
from ..bert.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from ..bert_generation.configuration_bert_generation import BertGenerationConfig
Vasudev Gupta's avatar
Vasudev Gupta committed
25
from ..big_bird.configuration_big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig
Vasudev Gupta's avatar
Vasudev Gupta committed
26
27
28
29
from ..bigbird_pegasus.configuration_bigbird_pegasus import (
    BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP,
    BigBirdPegasusConfig,
)
Sylvain Gugger's avatar
Sylvain Gugger committed
30
from ..blenderbot.configuration_blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig
31
32
33
34
from ..blenderbot_small.configuration_blenderbot_small import (
    BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
    BlenderbotSmallConfig,
)
Sylvain Gugger's avatar
Sylvain Gugger committed
35
from ..camembert.configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
Suraj Patil's avatar
Suraj Patil committed
36
from ..clip.configuration_clip import CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP, CLIPConfig
abhishek thakur's avatar
abhishek thakur committed
37
from ..convbert.configuration_convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig
Sylvain Gugger's avatar
Sylvain Gugger committed
38
39
from ..ctrl.configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
from ..deberta.configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig
40
from ..deberta_v2.configuration_deberta_v2 import DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaV2Config
NielsRogge's avatar
NielsRogge committed
41
from ..deit.configuration_deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig
Sylvain Gugger's avatar
Sylvain Gugger committed
42
43
44
45
46
47
48
49
from ..distilbert.configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
from ..dpr.configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig
from ..electra.configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
from ..encoder_decoder.configuration_encoder_decoder import EncoderDecoderConfig
from ..flaubert.configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
from ..fsmt.configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig
from ..funnel.configuration_funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig
from ..gpt2.configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
Suraj Patil's avatar
Suraj Patil committed
50
from ..gpt_neo.configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
Sehoon Kim's avatar
Sehoon Kim committed
51
from ..ibert.configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
Sylvain Gugger's avatar
Sylvain Gugger committed
52
from ..layoutlm.configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
Patrick von Platen's avatar
Patrick von Platen committed
53
from ..led.configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig
Sylvain Gugger's avatar
Sylvain Gugger committed
54
from ..longformer.configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
NielsRogge's avatar
NielsRogge committed
55
from ..luke.configuration_luke import LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, LukeConfig
Sylvain Gugger's avatar
Sylvain Gugger committed
56
from ..lxmert.configuration_lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig
Suraj Patil's avatar
Suraj Patil committed
57
from ..m2m_100.configuration_m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config
Sylvain Gugger's avatar
Sylvain Gugger committed
58
59
from ..marian.configuration_marian import MarianConfig
from ..mbart.configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig
60
from ..megatron_bert.configuration_megatron_bert import MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronBertConfig
Sylvain Gugger's avatar
Sylvain Gugger committed
61
from ..mobilebert.configuration_mobilebert import MobileBertConfig
StillKeepTry's avatar
StillKeepTry committed
62
from ..mpnet.configuration_mpnet import MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP, MPNetConfig
Patrick von Platen's avatar
Patrick von Platen committed
63
from ..mt5.configuration_mt5 import MT5Config
Sylvain Gugger's avatar
Sylvain Gugger committed
64
65
66
67
68
69
70
from ..openai.configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from ..pegasus.configuration_pegasus import PegasusConfig
from ..prophetnet.configuration_prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig
from ..rag.configuration_rag import RagConfig
from ..reformer.configuration_reformer import ReformerConfig
from ..retribert.configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from ..roberta.configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
Suraj Patil's avatar
Suraj Patil committed
71
72
73
74
from ..speech_to_text.configuration_speech_to_text import (
    SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    Speech2TextConfig,
)
Sylvain Gugger's avatar
Sylvain Gugger committed
75
76
from ..squeezebert.configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig
from ..t5.configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
NielsRogge's avatar
NielsRogge committed
77
from ..tapas.configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig
Sylvain Gugger's avatar
Sylvain Gugger committed
78
from ..transfo_xl.configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
79
from ..vit.configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
Patrick von Platen's avatar
Patrick von Platen committed
80
from ..wav2vec2.configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config
Sylvain Gugger's avatar
Sylvain Gugger committed
81
82
83
84
85
86
87
from ..xlm.configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
from ..xlm_prophetnet.configuration_xlm_prophetnet import (
    XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
    XLMProphetNetConfig,
)
from ..xlm_roberta.configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
from ..xlnet.configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
Aymeric Augustin's avatar
Aymeric Augustin committed
88

89

90
91
# Flat view over every per-model archive map: each (key, value) pair of each map listed below is
# merged into a single dict. Maps are merged in list order, so for a key present in several maps the
# value from the map listed last wins.
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    key: value
    for pretrained_map in [
        # Add archive maps here
        CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP,
        BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP,
        DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP,
        GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP,
        BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP,
        MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        VIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
        M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP,
        CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        LED_PRETRAINED_CONFIG_ARCHIVE_MAP,
        BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        BART_PRETRAINED_CONFIG_ARCHIVE_MAP,
        BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        MBART_PRETRAINED_CONFIG_ARCHIVE_MAP,
        OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
        GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
        CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
        XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
        XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
        ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
        DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
        XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
        FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
        LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
        RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP,
        LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
        DPR_PRETRAINED_CONFIG_ARCHIVE_MAP,
        DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
        DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP,
        SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
        PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
        MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
        TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP,
        IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ]
    for key, value in pretrained_map.items()
}
144
145


146
147
# Maps a `model_type` string to its configuration class.
#
# NOTE: the order of the entries is significant. When a loaded config has no `model_type` key,
# `AutoConfig.from_pretrained` iterates this mapping in order and picks the first key that is a
# substring of the model name/path, so more specific keys (e.g. "distilbert", "xlm-roberta") must be
# listed before the shorter keys they contain (e.g. "bert", "roberta", "xlm").
CONFIG_MAPPING = OrderedDict(
    [
        # Add configs here
        ("clip", CLIPConfig),
        ("bigbird_pegasus", BigBirdPegasusConfig),
        ("deit", DeiTConfig),
        ("luke", LukeConfig),
        ("gpt_neo", GPTNeoConfig),
        ("big_bird", BigBirdConfig),
        ("speech_to_text", Speech2TextConfig),
        ("vit", ViTConfig),
        ("wav2vec2", Wav2Vec2Config),
        ("m2m_100", M2M100Config),
        ("convbert", ConvBertConfig),
        ("led", LEDConfig),
        ("blenderbot-small", BlenderbotSmallConfig),
        ("retribert", RetriBertConfig),
        ("ibert", IBertConfig),
        ("mt5", MT5Config),
        ("t5", T5Config),
        ("mobilebert", MobileBertConfig),
        ("distilbert", DistilBertConfig),
        ("albert", AlbertConfig),
        ("bert-generation", BertGenerationConfig),
        ("camembert", CamembertConfig),
        ("xlm-roberta", XLMRobertaConfig),
        ("pegasus", PegasusConfig),
        ("marian", MarianConfig),
        ("mbart", MBartConfig),
        ("megatron_bert", MegatronBertConfig),
        ("mpnet", MPNetConfig),
        ("bart", BartConfig),
        ("blenderbot", BlenderbotConfig),
        ("reformer", ReformerConfig),
        ("longformer", LongformerConfig),
        ("roberta", RobertaConfig),
        ("deberta-v2", DebertaV2Config),
        ("deberta", DebertaConfig),
        ("flaubert", FlaubertConfig),
        ("fsmt", FSMTConfig),
        ("squeezebert", SqueezeBertConfig),
        ("bert", BertConfig),
        ("openai-gpt", OpenAIGPTConfig),
        ("gpt2", GPT2Config),
        ("transfo-xl", TransfoXLConfig),
        ("xlnet", XLNetConfig),
        ("xlm-prophetnet", XLMProphetNetConfig),
        ("prophetnet", ProphetNetConfig),
        ("xlm", XLMConfig),
        ("ctrl", CTRLConfig),
        ("electra", ElectraConfig),
        ("encoder-decoder", EncoderDecoderConfig),
        ("funnel", FunnelConfig),
        ("lxmert", LxmertConfig),
        ("dpr", DPRConfig),
        ("layoutlm", LayoutLMConfig),
        ("rag", RagConfig),
        ("tapas", TapasConfig),
    ]
)

# Maps a `model_type` string to the full (and cased) model name used when rendering the
# generated "List options" sections of the auto-class docstrings.
MODEL_NAMES_MAPPING = OrderedDict(
    [
        # Add full (and cased) model names here
        ("clip", "CLIP"),
        ("bigbird_pegasus", "BigBirdPegasus"),
        ("deit", "DeiT"),
        ("luke", "LUKE"),
        ("gpt_neo", "GPT Neo"),
        ("big_bird", "BigBird"),
        ("speech_to_text", "Speech2Text"),
        ("vit", "ViT"),
        ("wav2vec2", "Wav2Vec2"),
        ("m2m_100", "M2M100"),
        ("convbert", "ConvBERT"),
        ("led", "LED"),
        ("blenderbot-small", "BlenderbotSmall"),
        ("retribert", "RetriBERT"),
        ("ibert", "I-BERT"),
        ("t5", "T5"),
        ("mobilebert", "MobileBERT"),
        ("distilbert", "DistilBERT"),
        ("albert", "ALBERT"),
        ("bert-generation", "Bert Generation"),
        ("camembert", "CamemBERT"),
        ("xlm-roberta", "XLM-RoBERTa"),
        ("pegasus", "Pegasus"),
        ("blenderbot", "Blenderbot"),
        ("marian", "Marian"),
        ("mbart", "mBART"),
        ("megatron_bert", "MegatronBert"),
        ("bart", "BART"),
        ("reformer", "Reformer"),
        ("longformer", "Longformer"),
        ("roberta", "RoBERTa"),
        ("flaubert", "FlauBERT"),
        ("fsmt", "FairSeq Machine-Translation"),
        ("squeezebert", "SqueezeBERT"),
        ("bert", "BERT"),
        ("openai-gpt", "OpenAI GPT"),
        ("gpt2", "OpenAI GPT-2"),
        ("transfo-xl", "Transformer-XL"),
        ("xlnet", "XLNet"),
        ("xlm", "XLM"),
        ("ctrl", "CTRL"),
        ("electra", "ELECTRA"),
        ("encoder-decoder", "Encoder decoder"),
        ("funnel", "Funnel Transformer"),
        ("lxmert", "LXMERT"),
        ("deberta-v2", "DeBERTa-v2"),
        ("deberta", "DeBERTa"),
        ("layoutlm", "LayoutLM"),
        ("dpr", "DPR"),
        ("rag", "RAG"),
        ("xlm-prophetnet", "XLMProphetNet"),
        ("prophetnet", "ProphetNet"),
        ("mt5", "mT5"),
        ("mpnet", "MPNet"),
        ("tapas", "TAPAS"),
    ]
)


269
270
271
272
273
274
def _get_class_name(model_class):
    if isinstance(model_class, (list, tuple)):
        return " or ".join([f":class:`~transformers.{c.__name__}`" for c in model_class])
    return f":class:`~transformers.{model_class.__name__}`"


275
276
277
278
279
def _list_model_options(indent, config_to_class=None, use_model_types=True):
    if config_to_class is None and not use_model_types:
        raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
    if use_model_types:
        if config_to_class is None:
280
281
282
283
            model_type_to_name = {
                model_type: f":class:`~transformers.{config.__name__}`"
                for model_type, config in CONFIG_MAPPING.items()
            }
284
285
        else:
            model_type_to_name = {
286
                model_type: _get_class_name(config_to_class[config])
287
288
289
290
                for model_type, config in CONFIG_MAPPING.items()
                if config in config_to_class
            }
        lines = [
291
            f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
292
            for model_type in sorted(model_type_to_name.keys())
293
294
        ]
    else:
295
        config_to_name = {config.__name__: _get_class_name(clas) for config, clas in config_to_class.items()}
296
297
298
299
        config_to_model_name = {
            config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items()
        }
        lines = [
300
            f"{indent}- :class:`~transformers.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
301
            for config_name in sorted(config_to_name.keys())
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
        ]
    return "\n".join(lines)


def replace_list_option_in_docstrings(config_to_class=None, use_model_types=True):
    """
    Decorator factory that fills the ``List options`` placeholder of a function's docstring.

    The first docstring line consisting only of ``List options`` (optionally surrounded by whitespace) is replaced
    by the output of :func:`_list_model_options`, using the placeholder's own leading whitespace as indentation
    (plus four extra spaces when ``use_model_types`` is :obj:`True`).

    Args:
        config_to_class (:obj:`dict`, `optional`):
            Forwarded to :func:`_list_model_options`.
        use_model_types (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Forwarded to :func:`_list_model_options`.

    Returns:
        A decorator that rewrites ``fn.__doc__`` in place and returns ``fn``.

    Raises:
        ValueError: At decoration time, if the function has a docstring without a ``List options`` placeholder.
    """

    def docstring_decorator(fn):
        docstrings = fn.__doc__
        if docstrings is None:
            # No docstring to patch (e.g. Python started with -OO, which strips docstrings).
            return fn
        lines = docstrings.split("\n")
        placeholder = re.compile(r"^(\s*)List options\s*$")
        for i, line in enumerate(lines):
            match = placeholder.search(line)
            if match is not None:
                break
        else:
            raise ValueError(
                f"The function {fn} should have an empty 'List options' in its docstring as placeholder, current docstring is:\n{docstrings}"
            )
        indent = match.groups()[0]
        if use_model_types:
            indent = f"{indent}    "
        lines[i] = _list_model_options(indent, config_to_class=config_to_class, use_model_types=use_model_types)
        fn.__doc__ = "\n".join(lines)
        return fn

    return docstring_decorator


Julien Chaumond's avatar
Julien Chaumond committed
329
class AutoConfig:
    r"""
    This is a generic configuration class that will be instantiated as one of the configuration classes of the library
    when created with the :meth:`~transformers.AutoConfig.from_pretrained` class method.

    This class cannot be instantiated directly using ``__init__()`` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoConfig is designed to be instantiated "
            "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    def for_model(cls, model_type: str, *args, **kwargs):
        """
        Instantiate the configuration class registered in ``CONFIG_MAPPING`` for :obj:`model_type`.

        Args:
            model_type (:obj:`str`): A key of ``CONFIG_MAPPING``.
            args, kwargs: Forwarded unchanged to the configuration class constructor.

        Raises:
            ValueError: If :obj:`model_type` is not a key of ``CONFIG_MAPPING``.
        """
        if model_type in CONFIG_MAPPING:
            config_class = CONFIG_MAPPING[model_type]
            return config_class(*args, **kwargs)
        raise ValueError(
            f"Unrecognized model identifier: {model_type}. Should contain one of {', '.join(CONFIG_MAPPING.keys())}"
        )

    @classmethod
    @replace_list_option_in_docstrings()
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r"""
        Instantiate one of the configuration classes of the library from a pretrained model configuration.

        The configuration class to instantiate is selected based on the :obj:`model_type` property of the config object
        that is loaded, or when it's missing, by falling back to using pattern matching on
        :obj:`pretrained_model_name_or_path`:

        List options

        Args:
            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                Can be either:

                    - A string, the `model id` of a pretrained model configuration hosted inside a model repo on
                      huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
                      namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
                    - A path to a `directory` containing a configuration file saved using the
                      :meth:`~transformers.PretrainedConfig.save_pretrained` method, or the
                      :meth:`~transformers.PreTrainedModel.save_pretrained` method, e.g., ``./my_model_directory/``.
                    - A path or url to a saved configuration JSON `file`, e.g.,
                      ``./my_model_directory/configuration.json``.
            cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the model weights and configuration files and override
                the cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
                identifier allowed by git.
            return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
                If :obj:`False`, then this function returns just the final configuration object.

                If :obj:`True`, then this function returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs`
                is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e.,
                the part of ``kwargs`` which has not been used to update ``config`` and is otherwise ignored.
            kwargs(additional keyword arguments, `optional`):
                The values in kwargs of any keys which are configuration attributes will be used to override the loaded
                values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
                by the ``return_unused_kwargs`` keyword parameter.

        Examples::

            >>> from transformers import AutoConfig

            >>> # Download configuration from huggingface.co and cache.
            >>> config = AutoConfig.from_pretrained('bert-base-uncased')

            >>> # Download configuration from huggingface.co (user-uploaded) and cache.
            >>> config = AutoConfig.from_pretrained('dbmdz/bert-base-german-cased')

            >>> # If configuration file is in a directory (e.g., was saved using `save_pretrained('./test/bert_saved_model/')`).
            >>> config = AutoConfig.from_pretrained('./test/bert_saved_model/')

            >>> # Load a specific configuration file.
            >>> config = AutoConfig.from_pretrained('./test/bert_saved_model/my_configuration.json')

            >>> # Change some config attributes when loading a pretrained config.
            >>> config = AutoConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False)
            >>> config.output_attentions
            True
            >>> config, unused_kwargs = AutoConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False, return_unused_kwargs=True)
            >>> config.output_attentions
            True
            >>> unused_kwargs
            {'foo': False}
        """
        # Extra flag forwarded with the other kwargs; presumably marks the call as coming from an
        # auto class -- confirm against `PretrainedConfig.get_config_dict`.
        kwargs["_from_auto"] = True
        config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
        if "model_type" in config_dict:
            # NOTE: a `model_type` that is not registered in CONFIG_MAPPING raises a KeyError here.
            config_class = CONFIG_MAPPING[config_dict["model_type"]]
            return config_class.from_dict(config_dict, **kwargs)
        else:
            # Fallback: use pattern matching on the string. The first CONFIG_MAPPING key (in
            # insertion order) found as a substring of the name/path wins.
            for pattern, config_class in CONFIG_MAPPING.items():
                if pattern in str(pretrained_model_name_or_path):
                    return config_class.from_dict(config_dict, **kwargs)

        raise ValueError(
            f"Unrecognized model in {pretrained_model_name_or_path}. "
            "Should have a `model_type` key in its config.json, or contain one of the following strings "
            f"in its name: {', '.join(CONFIG_MAPPING.keys())}"
        )