"tests/test_modeling_tf_flaubert.py" did not exist on "e392ba6938f50655a195ea7ec8a260b1e9fc6058"
configuration_auto.py 16.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
""" Auto Config class. """
16

17
import re
18
from collections import OrderedDict
19

Sylvain Gugger's avatar
Sylvain Gugger committed
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from ...configuration_utils import PretrainedConfig
from ..albert.configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from ..bart.configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
from ..bert.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from ..bert_generation.configuration_bert_generation import BertGenerationConfig
from ..blenderbot.configuration_blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig
from ..camembert.configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from ..ctrl.configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
from ..deberta.configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig
from ..distilbert.configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
from ..dpr.configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig
from ..electra.configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
from ..encoder_decoder.configuration_encoder_decoder import EncoderDecoderConfig
from ..flaubert.configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
from ..fsmt.configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig
from ..funnel.configuration_funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig
from ..gpt2.configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from ..layoutlm.configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
from ..longformer.configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
from ..lxmert.configuration_lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig
from ..marian.configuration_marian import MarianConfig
from ..mbart.configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig
from ..mobilebert.configuration_mobilebert import MobileBertConfig
from ..openai.configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from ..pegasus.configuration_pegasus import PegasusConfig
from ..prophetnet.configuration_prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig
from ..rag.configuration_rag import RagConfig
from ..reformer.configuration_reformer import ReformerConfig
from ..retribert.configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from ..roberta.configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from ..squeezebert.configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig
from ..t5.configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from ..transfo_xl.configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
from ..xlm.configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
from ..xlm_prophetnet.configuration_xlm_prophetnet import (
    XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
    XLMProphetNetConfig,
)
from ..xlm_roberta.configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
from ..xlnet.configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
Aymeric Augustin's avatar
Aymeric Augustin committed
60

61

62
63
# Union of every model's pretrained-config archive map, flattened into a single dict.
# Maps are merged in the order listed below, so if the same key appears in several maps the one listed later wins.
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    key: value
    for pretrained_map in [
        # Add archive maps here
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        BART_PRETRAINED_CONFIG_ARCHIVE_MAP,
        BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        MBART_PRETRAINED_CONFIG_ARCHIVE_MAP,
        OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
        GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
        CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
        XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
        XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
        ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
        DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
        XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
        FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
        LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
        RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP,
        LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
        DPR_PRETRAINED_CONFIG_ARCHIVE_MAP,
        DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
        SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
        PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ]
    for key, value in pretrained_map.items()
}
98
99


100
101
# Maps each ``model_type`` string (the value read from the ``"model_type"`` key of a loaded config dict) to the
# configuration class that handles it.
# ORDER MATTERS: when a config has no ``model_type``, AutoConfig.from_pretrained scans these keys in order and uses
# the first one that is a substring of the checkpoint name. A key that contains another key must therefore be listed
# before it (e.g. "xlm-roberta" before "roberta" and "xlm", "roberta" before "bert", "mbart" before "bart").
CONFIG_MAPPING = OrderedDict(
    [
        # Add configs here
        ("retribert", RetriBertConfig),
        ("t5", T5Config),
        ("mobilebert", MobileBertConfig),
        ("distilbert", DistilBertConfig),
        ("albert", AlbertConfig),
        ("bert-generation", BertGenerationConfig),
        ("camembert", CamembertConfig),
        ("xlm-roberta", XLMRobertaConfig),
        ("pegasus", PegasusConfig),
        ("marian", MarianConfig),
        ("mbart", MBartConfig),
        ("bart", BartConfig),
        ("blenderbot", BlenderbotConfig),
        ("reformer", ReformerConfig),
        ("longformer", LongformerConfig),
        ("roberta", RobertaConfig),
        ("deberta", DebertaConfig),
        ("flaubert", FlaubertConfig),
        ("fsmt", FSMTConfig),
        ("squeezebert", SqueezeBertConfig),
        ("bert", BertConfig),
        ("openai-gpt", OpenAIGPTConfig),
        ("gpt2", GPT2Config),
        ("transfo-xl", TransfoXLConfig),
        ("xlnet", XLNetConfig),
        ("xlm-prophetnet", XLMProphetNetConfig),
        ("prophetnet", ProphetNetConfig),
        ("xlm", XLMConfig),
        ("ctrl", CTRLConfig),
        ("electra", ElectraConfig),
        ("encoder-decoder", EncoderDecoderConfig),
        ("funnel", FunnelConfig),
        ("lxmert", LxmertConfig),
        ("dpr", DPRConfig),
        ("layoutlm", LayoutLMConfig),
        ("rag", RagConfig),
    ]
)

# Full, cased, human-readable model name for each ``model_type`` key. `_list_model_options` looks up
# ``MODEL_NAMES_MAPPING[model_type]`` for every key of CONFIG_MAPPING when generating docstring lists, so every key of
# CONFIG_MAPPING must also be present here.
MODEL_NAMES_MAPPING = OrderedDict(
    [
        # Add full (and cased) model names here
        ("retribert", "RetriBERT"),
        ("t5", "T5"),
        ("mobilebert", "MobileBERT"),
        ("distilbert", "DistilBERT"),
        ("albert", "ALBERT"),
        ("bert-generation", "Bert Generation"),
        ("camembert", "CamemBERT"),
        ("xlm-roberta", "XLM-RoBERTa"),
        ("pegasus", "Pegasus"),
        ("blenderbot", "Blenderbot"),
        ("marian", "Marian"),
        ("mbart", "mBART"),
        ("bart", "BART"),
        ("reformer", "Reformer"),
        ("longformer", "Longformer"),
        ("roberta", "RoBERTa"),
        ("flaubert", "FlauBERT"),
        ("fsmt", "FairSeq Machine-Translation"),
        ("squeezebert", "SqueezeBERT"),
        ("bert", "BERT"),
        ("openai-gpt", "OpenAI GPT"),
        ("gpt2", "OpenAI GPT-2"),
        ("transfo-xl", "Transformer-XL"),
        ("xlnet", "XLNet"),
        ("xlm", "XLM"),
        ("ctrl", "CTRL"),
        ("electra", "ELECTRA"),
        ("encoder-decoder", "Encoder decoder"),
        ("funnel", "Funnel Transformer"),
        ("lxmert", "LXMERT"),
        ("deberta", "DeBERTa"),
        ("layoutlm", "LayoutLM"),
        ("dpr", "DPR"),
        ("rag", "RAG"),
        ("xlm-prophetnet", "XLMProphetNet"),
        ("prophetnet", "ProphetNet"),
    ]
)


185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
def _list_model_options(indent, config_to_class=None, use_model_types=True):
    if config_to_class is None and not use_model_types:
        raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
    if use_model_types:
        if config_to_class is None:
            model_type_to_name = {model_type: config.__name__ for model_type, config in CONFIG_MAPPING.items()}
        else:
            model_type_to_name = {
                model_type: config_to_class[config].__name__
                for model_type, config in CONFIG_MAPPING.items()
                if config in config_to_class
            }
        lines = [
            f"{indent}- **{model_type}** -- :class:`~transformers.{cls_name}` ({MODEL_NAMES_MAPPING[model_type]} model)"
            for model_type, cls_name in model_type_to_name.items()
        ]
    else:
        config_to_name = {config.__name__: clas.__name__ for config, clas in config_to_class.items()}
        config_to_model_name = {
            config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items()
        }
        lines = [
            f"{indent}- :class:`~transformers.{config_name}` configuration class: :class:`~transformers.{cls_name}` ({config_to_model_name[config_name]} model)"
            for config_name, cls_name in config_to_name.items()
        ]
    return "\n".join(lines)


def replace_list_option_in_docstrings(config_to_class=None, use_model_types=True):
    """
    Build a decorator that replaces the "List options" placeholder line of a function's docstring.

    The first docstring line made only of optional leading whitespace followed by ``List options`` is replaced by the
    output of :func:`_list_model_options`. When ``use_model_types`` is True, the generated lines are indented four
    spaces deeper than the placeholder.

    Args:
        config_to_class (:obj:`dict`, `optional`):
            Forwarded to :func:`_list_model_options`.
        use_model_types (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Forwarded to :func:`_list_model_options`; also controls the extra indentation.

    Returns:
        A decorator that rewrites ``fn.__doc__`` and returns ``fn`` itself. A function without a docstring (e.g.
        when Python runs with ``-OO``, which strips docstrings) is returned unchanged.

    Raises:
        ValueError: Raised by the decorator when the docstring has no ``List options`` placeholder line.
    """
    placeholder_pattern = re.compile(r"^(\s*)List options\s*$")

    def docstring_decorator(fn):
        docstrings = fn.__doc__
        if docstrings is None:
            # Docstrings are stripped under `python -OO`; calling `.split` on None would raise AttributeError while
            # the decorated class is being defined, so there is simply nothing to rewrite.
            return fn
        lines = docstrings.split("\n")
        for i, line in enumerate(lines):
            match = placeholder_pattern.search(line)
            if match is not None:
                break
        else:
            raise ValueError(
                f"The function {fn} should have an empty 'List options' in its docstring as placeholder, current docstring is:\n{docstrings}"
            )
        indent = match.groups()[0]
        if use_model_types:
            # Model-type lists are nested one level under the placeholder's indentation.
            indent = f"{indent}    "
        lines[i] = _list_model_options(indent, config_to_class=config_to_class, use_model_types=use_model_types)
        fn.__doc__ = "\n".join(lines)
        return fn

    return docstring_decorator


Julien Chaumond's avatar
Julien Chaumond committed
236
class AutoConfig:
    r"""
    This is a generic configuration class that will be instantiated as one of the configuration classes of the library
    when created with the :meth:`~transformers.AutoConfig.from_pretrained` class method.

    This class cannot be instantiated directly using ``__init__()`` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoConfig is designed to be instantiated "
            "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    def for_model(cls, model_type: str, *args, **kwargs):
        r"""
        Instantiate the configuration class registered for :obj:`model_type` in ``CONFIG_MAPPING``.

        Args:
            model_type (:obj:`str`):
                A key of ``CONFIG_MAPPING``, e.g. ``"bert"``.
            args, kwargs:
                Forwarded unchanged to the constructor of the selected configuration class.

        Raises:
            ValueError: If :obj:`model_type` is not a key of ``CONFIG_MAPPING``.
        """
        if model_type in CONFIG_MAPPING:
            config_class = CONFIG_MAPPING[model_type]
            return config_class(*args, **kwargs)
        raise ValueError(
            f"Unrecognized model identifier: {model_type}. Should contain one of {', '.join(CONFIG_MAPPING.keys())}"
        )

    @classmethod
    @replace_list_option_in_docstrings()
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r"""
        Instantiate one of the configuration classes of the library from a pretrained model configuration.

        The configuration class to instantiate is selected based on the :obj:`model_type` property of the config object
        that is loaded, or when it's missing, by falling back to using pattern matching on
        :obj:`pretrained_model_name_or_path`:

        List options

        Args:
            pretrained_model_name_or_path (:obj:`str`):
                Can be either:

                    - A string with the `shortcut name` of a pretrained model configuration to load from cache or
                      download, e.g., ``bert-base-uncased``.
                    - A string with the `identifier name` of a pretrained model configuration that was user-uploaded to
                      our S3, e.g., ``dbmdz/bert-base-german-cased``.
                    - A path to a `directory` containing a configuration file saved using the
                      :meth:`~transformers.PretrainedConfig.save_pretrained` method, or the
                      :meth:`~transformers.PreTrainedModel.save_pretrained` method, e.g., ``./my_model_directory/``.
                    - A path or url to a saved configuration JSON `file`, e.g.,
                      ``./my_model_directory/configuration.json``.
            cache_dir (:obj:`str`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the model weights and configuration files and override
                the cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
                identifier allowed by git.
            return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
                If :obj:`False`, then this function returns just the final configuration object.

                If :obj:`True`, then this function returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs`
                is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e.,
                the part of ``kwargs`` which has not been used to update ``config`` and is otherwise ignored.
            kwargs (additional keyword arguments, `optional`):
                The values in kwargs of any keys which are configuration attributes will be used to override the loaded
                values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
                by the ``return_unused_kwargs`` keyword parameter.

        Raises:
            KeyError: If the loaded configuration has a ``model_type`` that is not a key of ``CONFIG_MAPPING``.
            ValueError: If the loaded configuration has no ``model_type`` and no key of ``CONFIG_MAPPING`` is a
                substring of :obj:`pretrained_model_name_or_path`.

        Examples::

            >>> from transformers import AutoConfig

            >>> # Download configuration from S3 and cache.
            >>> config = AutoConfig.from_pretrained('bert-base-uncased')

            >>> # Download configuration from S3 (user-uploaded) and cache.
            >>> config = AutoConfig.from_pretrained('dbmdz/bert-base-german-cased')

            >>> # If configuration file is in a directory (e.g., was saved using `save_pretrained('./test/saved_model/')`).
            >>> config = AutoConfig.from_pretrained('./test/bert_saved_model/')

            >>> # Load a specific configuration file.
            >>> config = AutoConfig.from_pretrained('./test/bert_saved_model/my_configuration.json')

            >>> # Change some config attributes when loading a pretrained config.
            >>> config = AutoConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False)
            >>> config.output_attentions
            True
            >>> config, unused_kwargs = AutoConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False, return_unused_kwargs=True)
            >>> config.output_attentions
            True
            >>> unused_kwargs
            {'foo': False}
        """
        config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)

        if "model_type" in config_dict:
            # An explicit model type always wins; an unregistered value raises KeyError here.
            config_class = CONFIG_MAPPING[config_dict["model_type"]]
            return config_class.from_dict(config_dict, **kwargs)

        # Fallback: use pattern matching on the string. CONFIG_MAPPING is scanned in order and the first key that is a
        # substring of the name wins, so a key containing another key must be listed before it in CONFIG_MAPPING.
        for pattern, config_class in CONFIG_MAPPING.items():
            if pattern in pretrained_model_name_or_path:
                return config_class.from_dict(config_dict, **kwargs)

        raise ValueError(
            f"Unrecognized model in {pretrained_model_name_or_path}. "
            "Should have a `model_type` key in its config.json, or contain one of the following strings "
            f"in its name: {', '.join(CONFIG_MAPPING.keys())}"
        )