# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class. """


import logging
from collections import OrderedDict

from .configuration_auto import (
    AlbertConfig,
    AutoConfig,
    BartConfig,
    BertConfig,
    CamembertConfig,
    CTRLConfig,
    DistilBertConfig,
    ElectraConfig,
    FlaubertConfig,
    GPT2Config,
    OpenAIGPTConfig,
    RobertaConfig,
    T5Config,
    TransfoXLConfig,
    XLMConfig,
    XLMRobertaConfig,
    XLNetConfig,
)
from .configuration_utils import PretrainedConfig
from .modeling_albert import (
    ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
    AlbertForMaskedLM,
    AlbertForQuestionAnswering,
    AlbertForSequenceClassification,
    AlbertForTokenClassification,
    AlbertModel,
)
from .modeling_bart import (
    BART_PRETRAINED_MODEL_ARCHIVE_MAP,
    BartForConditionalGeneration,
    BartForSequenceClassification,
    BartModel,
)
from .modeling_bert import (
    BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
    BertForMaskedLM,
    BertForMultipleChoice,
    BertForPreTraining,
    BertForQuestionAnswering,
    BertForSequenceClassification,
    BertForTokenClassification,
    BertModel,
)
from .modeling_camembert import (
    CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
    CamembertForMaskedLM,
    CamembertForMultipleChoice,
    CamembertForSequenceClassification,
    CamembertForTokenClassification,
    CamembertModel,
)
from .modeling_ctrl import CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRLLMHeadModel, CTRLModel
from .modeling_distilbert import (
    DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
    DistilBertForMaskedLM,
    DistilBertForQuestionAnswering,
    DistilBertForSequenceClassification,
    DistilBertForTokenClassification,
    DistilBertModel,
)
from .modeling_electra import (
    ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
    ElectraForMaskedLM,
    ElectraForPreTraining,
    ElectraForTokenClassification,
    ElectraModel,
)
from .modeling_flaubert import (
    FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
    FlaubertForQuestionAnsweringSimple,
    FlaubertForSequenceClassification,
    FlaubertModel,
    FlaubertWithLMHeadModel,
)
from .modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2Model
from .modeling_openai import OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OpenAIGPTModel
from .modeling_roberta import (
    ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
    RobertaForMaskedLM,
    RobertaForMultipleChoice,
    RobertaForQuestionAnswering,
    RobertaForSequenceClassification,
    RobertaForTokenClassification,
    RobertaModel,
)
from .modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5ForConditionalGeneration, T5Model
from .modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TransfoXLLMHeadModel, TransfoXLModel
from .modeling_xlm import (
    XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
    XLMForQuestionAnsweringSimple,
    XLMForSequenceClassification,
    XLMForTokenClassification,
    XLMModel,
    XLMWithLMHeadModel,
)
from .modeling_xlm_roberta import (
    XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
    XLMRobertaForMaskedLM,
    XLMRobertaForMultipleChoice,
    XLMRobertaForSequenceClassification,
    XLMRobertaForTokenClassification,
    XLMRobertaModel,
)
from .modeling_xlnet import (
    XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
    XLNetForMultipleChoice,
    XLNetForQuestionAnsweringSimple,
    XLNetForSequenceClassification,
    XLNetForTokenClassification,
    XLNetLMHeadModel,
    XLNetModel,
)


logger = logging.getLogger(__name__)


ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
    (key, value)
    for pretrained_map in [
        BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        BART_PRETRAINED_MODEL_ARCHIVE_MAP,
        OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
        TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
        GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
        CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
        ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        T5_PRETRAINED_MODEL_ARCHIVE_MAP,
        FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
        ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
    ]
    for key, value in pretrained_map.items()
)
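
# A quick sketch of how the merged archive map can be used; the
# "bert-base-uncased" shortcut below is assumed to be one of the keys
# contributed by BERT_PRETRAINED_MODEL_ARCHIVE_MAP:
#
#     url = ALL_PRETRAINED_MODEL_ARCHIVE_MAP["bert-base-uncased"]
#     # -> URL of the pretrained weights for that shortcut name
#
# If two per-model maps shared a key, the generator above would keep the value
# seen last, since dict() retains the last value written for a key.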

MODEL_MAPPING = OrderedDict(
    [
        (T5Config, T5Model),
        (DistilBertConfig, DistilBertModel),
        (AlbertConfig, AlbertModel),
        (CamembertConfig, CamembertModel),
        (XLMRobertaConfig, XLMRobertaModel),
        (BartConfig, BartModel),
        (RobertaConfig, RobertaModel),
        (BertConfig, BertModel),
        (OpenAIGPTConfig, OpenAIGPTModel),
        (GPT2Config, GPT2Model),
        (TransfoXLConfig, TransfoXLModel),
        (XLNetConfig, XLNetModel),
        (FlaubertConfig, FlaubertModel),
        (XLMConfig, XLMModel),
        (CTRLConfig, CTRLModel),
        (ElectraConfig, ElectraModel),
    ]
)
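
# Note: the order of entries in these OrderedDicts matters. The Auto classes
# below look configurations up with isinstance(), and several configuration
# classes subclass others (CamembertConfig and XLMRobertaConfig derive from
# RobertaConfig, which derives from BertConfig, and FlaubertConfig derives
# from XLMConfig), so the more specific configuration classes must be listed
# before their parents.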

MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
    [
        (T5Config, T5ForConditionalGeneration),
        (DistilBertConfig, DistilBertForMaskedLM),
        (AlbertConfig, AlbertForMaskedLM),
        (CamembertConfig, CamembertForMaskedLM),
        (XLMRobertaConfig, XLMRobertaForMaskedLM),
        (BartConfig, BartForConditionalGeneration),
        (RobertaConfig, RobertaForMaskedLM),
        (BertConfig, BertForPreTraining),
        (OpenAIGPTConfig, OpenAIGPTLMHeadModel),
        (GPT2Config, GPT2LMHeadModel),
        (TransfoXLConfig, TransfoXLLMHeadModel),
        (XLNetConfig, XLNetLMHeadModel),
        (FlaubertConfig, FlaubertWithLMHeadModel),
        (XLMConfig, XLMWithLMHeadModel),
        (CTRLConfig, CTRLLMHeadModel),
        (ElectraConfig, ElectraForPreTraining),
    ]
)
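
# MODEL_FOR_PRETRAINING_MAPPING and MODEL_WITH_LM_HEAD_MAPPING (below) differ
# only where a model's pretraining objective is not plain (masked) language
# modeling: BertConfig maps to BertForPreTraining here (masked LM plus
# next-sentence prediction heads) but to BertForMaskedLM below, and
# ElectraConfig maps to ElectraForPreTraining (the discriminator) here but to
# ElectraForMaskedLM (the generator) below.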

MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
    [
        (T5Config, T5ForConditionalGeneration),
        (DistilBertConfig, DistilBertForMaskedLM),
        (AlbertConfig, AlbertForMaskedLM),
        (CamembertConfig, CamembertForMaskedLM),
        (XLMRobertaConfig, XLMRobertaForMaskedLM),
        (BartConfig, BartForConditionalGeneration),
        (RobertaConfig, RobertaForMaskedLM),
        (BertConfig, BertForMaskedLM),
        (OpenAIGPTConfig, OpenAIGPTLMHeadModel),
        (GPT2Config, GPT2LMHeadModel),
        (TransfoXLConfig, TransfoXLLMHeadModel),
        (XLNetConfig, XLNetLMHeadModel),
        (FlaubertConfig, FlaubertWithLMHeadModel),
        (XLMConfig, XLMWithLMHeadModel),
        (CTRLConfig, CTRLLMHeadModel),
        (ElectraConfig, ElectraForMaskedLM),
    ]
)

MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
    [
        (DistilBertConfig, DistilBertForSequenceClassification),
        (AlbertConfig, AlbertForSequenceClassification),
        (CamembertConfig, CamembertForSequenceClassification),
        (XLMRobertaConfig, XLMRobertaForSequenceClassification),
        (BartConfig, BartForSequenceClassification),
        (RobertaConfig, RobertaForSequenceClassification),
        (BertConfig, BertForSequenceClassification),
        (XLNetConfig, XLNetForSequenceClassification),
        (FlaubertConfig, FlaubertForSequenceClassification),
        (XLMConfig, XLMForSequenceClassification),
    ]
)

MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
    [
        (DistilBertConfig, DistilBertForQuestionAnswering),
        (AlbertConfig, AlbertForQuestionAnswering),
        (RobertaConfig, RobertaForQuestionAnswering),
        (BertConfig, BertForQuestionAnswering),
        (XLNetConfig, XLNetForQuestionAnsweringSimple),
        (FlaubertConfig, FlaubertForQuestionAnsweringSimple),
        (XLMConfig, XLMForQuestionAnsweringSimple),
    ]
)
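
# The XLNet, Flaubert and XLM entries above use the "...QuestionAnsweringSimple"
# heads, which predict plain start/end span logits like the other models in
# this mapping, rather than the beam-search QA heads those architectures also
# provide.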

MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
    [
        (DistilBertConfig, DistilBertForTokenClassification),
        (CamembertConfig, CamembertForTokenClassification),
        (XLMConfig, XLMForTokenClassification),
        (XLMRobertaConfig, XLMRobertaForTokenClassification),
        (RobertaConfig, RobertaForTokenClassification),
        (BertConfig, BertForTokenClassification),
        (XLNetConfig, XLNetForTokenClassification),
        (AlbertConfig, AlbertForTokenClassification),
        (ElectraConfig, ElectraForTokenClassification),
    ]
)


MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
    [
        (CamembertConfig, CamembertForMultipleChoice),
        (XLMRobertaConfig, XLMRobertaForMultipleChoice),
        (RobertaConfig, RobertaForMultipleChoice),
        (BertConfig, BertForMultipleChoice),
        (XLNetConfig, XLNetForMultipleChoice),
    ]
)
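
# Every Auto class below resolves a configuration to a concrete model class
# with the same pattern; a minimal standalone sketch, where ``mapping`` and
# ``config`` stand in for one of the OrderedDicts above and a loaded
# PretrainedConfig:
#
#     def resolve(mapping, config):
#         for config_class, model_class in mapping.items():
#             if isinstance(config, config_class):
#                 return model_class
#         raise ValueError("unsupported configuration: {}".format(type(config)))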


class AutoModel:
    r"""
        :class:`~transformers.AutoModel` is a generic model class
        that will be instantiated as one of the base model classes of the library
        when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
        or the `AutoModel.from_config(config)` class methods.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModel is designed to be instantiated "
            "using the `AutoModel.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModel.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the base model classes of the library
        from a configuration.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertModel` (DistilBERT model)
                - isInstance of `roberta` configuration class: :class:`~transformers.RobertaModel` (RoBERTa model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertModel` (Bert model)
                - isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTModel` (OpenAI GPT model)
                - isInstance of `gpt2` configuration class: :class:`~transformers.GPT2Model` (OpenAI GPT-2 model)
                - isInstance of `ctrl` configuration class: :class:`~transformers.CTRLModel` (Salesforce CTRL model)
                - isInstance of `transfo-xl` configuration class: :class:`~transformers.TransfoXLModel` (Transformer-XL model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetModel` (XLNet model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMModel` (XLM model)
                - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertModel` (Flaubert model)
                - isInstance of `electra` configuration class: :class:`~transformers.ElectraModel` (Electra model)

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModel.from_config(config)  # Instantiate the model from the configuration (weights are randomly initialized, not loaded)
        """
        for config_class, model_class in MODEL_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_MAPPING.keys())
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the base model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The base model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: :class:`~transformers.T5Model` (T5 model)
            - contains `distilbert`: :class:`~transformers.DistilBertModel` (DistilBERT model)
            - contains `albert`: :class:`~transformers.AlbertModel` (ALBERT model)
            - contains `camembert`: :class:`~transformers.CamembertModel` (CamemBERT model)
            - contains `xlm-roberta`: :class:`~transformers.XLMRobertaModel` (XLM-RoBERTa model)
            - contains `roberta`: :class:`~transformers.RobertaModel` (RoBERTa model)
            - contains `bert`: :class:`~transformers.BertModel` (Bert model)
            - contains `openai-gpt`: :class:`~transformers.OpenAIGPTModel` (OpenAI GPT model)
            - contains `gpt2`: :class:`~transformers.GPT2Model` (OpenAI GPT-2 model)
            - contains `transfo-xl`: :class:`~transformers.TransfoXLModel` (Transformer-XL model)
            - contains `xlnet`: :class:`~transformers.XLNetModel` (XLNet model)
            - contains `xlm`: :class:`~transformers.XLMModel` (XLM model)
            - contains `ctrl`: :class:`~transformers.CTRLModel` (Salesforce CTRL model)
            - contains `flaubert`: :class:`~transformers.FlaubertModel` (Flaubert model)
            - contains `electra`: :class:`~transformers.ElectraModel` (Electra model)

            The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
            To train the model, you should first set it back in training mode with `model.train()`

        Args:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete an incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Examples::

            model = AutoModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModel.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = AutoModel.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

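        # The same kwargs used to auto-load the config above are also forwarded
        # to the selected model class; dispatch is by isinstance() against the
        # ordered MODEL_MAPPING, and the first matching entry wins.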
        for config_class, model_class in MODEL_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_MAPPING.keys())
            )
        )


class AutoModelForPreTraining:
    r"""
        :class:`~transformers.AutoModelForPreTraining` is a generic model class
        that will be instantiated as one of the model classes of the library (with the architecture used for pretraining this model) when created with the `AutoModelForPreTraining.from_pretrained(pretrained_model_name_or_path)`
        class method.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForPreTraining is designed to be instantiated "
            "using the `AutoModelForPreTraining.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForPreTraining.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the base model classes of the library
        from a configuration.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
                - isInstance of `roberta` configuration class: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertForPreTraining` (Bert model)
                - isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
                - isInstance of `gpt2` configuration class: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
                - isInstance of `ctrl` configuration class: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
                - isInstance of `transfo-xl` configuration class: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
                - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
                - isInstance of `electra` configuration class: :class:`~transformers.ElectraForPreTraining` (Electra model)

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelForPreTraining.from_config(config)  # Instantiate the model from the configuration (weights are randomly initialized, not loaded)
        """
        for config_class, model_class in MODEL_FOR_PRETRAINING_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_PRETRAINING_MAPPING.keys())
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the model classes of the library -with the architecture used for pretraining this model– from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: :class:`~transformers.T5ForConditionalGeneration` (T5 model)
            - contains `distilbert`: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
            - contains `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model)
            - contains `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
            - contains `xlm-roberta`: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model)
            - contains `roberta`: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
            - contains `bert`: :class:`~transformers.BertForPreTraining` (Bert model)
            - contains `openai-gpt`: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
            - contains `gpt2`: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
            - contains `transfo-xl`: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
            - contains `xlnet`: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
            - contains `xlm`: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
            - contains `ctrl`: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
            - contains `flaubert`: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
            - contains `electra`: :class:`~transformers.ElectraForPreTraining` (Electra model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`

        Args:
            pretrained_model_name_or_path:
                Either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method
            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.
            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
            resume_download: (`optional`) boolean, default False:
                Do not delete an incompletely received file. Attempt to resume the download if such a file exists.
            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.
            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Examples::

            model = AutoModelForPreTraining.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForPreTraining.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = AutoModelForPreTraining.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelForPreTraining.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        for config_class, model_class in MODEL_FOR_PRETRAINING_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_PRETRAINING_MAPPING.keys())
            )
        )


class AutoModelWithLMHead:
    r"""
        :class:`~transformers.AutoModelWithLMHead` is a generic model class
        that will be instantiated as one of the language modeling model classes of the library
        when created with the `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)`
        class method.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelWithLMHead is designed to be instantiated "
            "using the `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelWithLMHead.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the base model classes of the library
        from a configuration.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
                - isInstance of `roberta` configuration class: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertForMaskedLM` (Bert model)
                - isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
                - isInstance of `gpt2` configuration class: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
                - isInstance of `ctrl` configuration class: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
                - isInstance of `transfo-xl` configuration class: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
                - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
                - isInstance of `electra` configuration class: :class:`~transformers.ElectraForMaskedLM` (Electra model)

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelWithLMHead.from_config(config)  # Instantiate the model from the configuration (weights are randomly initialized, not loaded)
        """
        for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_WITH_LM_HEAD_MAPPING.keys())
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the language modeling model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: :class:`~transformers.T5ForConditionalGeneration` (T5 model)
            - contains `distilbert`: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
            - contains `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model)
            - contains `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
            - contains `xlm-roberta`: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model)
            - contains `roberta`: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
            - contains `bert`: :class:`~transformers.BertForMaskedLM` (Bert model)
            - contains `openai-gpt`: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
            - contains `gpt2`: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
            - contains `transfo-xl`: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
            - contains `xlnet`: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
            - contains `xlm`: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
            - contains `ctrl`: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
            - contains `flaubert`: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
            - contains `electra`: :class:`~transformers.ElectraForMaskedLM` (Electra model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`

        Args:
            pretrained_model_name_or_path:
                Either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method
            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.
            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
            resume_download: (`optional`) boolean, default False:
                Do not delete an incompletely received file. Attempt to resume the download if such a file exists.
            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.
            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Examples::

            model = AutoModelWithLMHead.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelWithLMHead.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = AutoModelWithLMHead.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_WITH_LM_HEAD_MAPPING.keys())
            )
        )


class AutoModelForSequenceClassification:
    r"""
        :class:`~transformers.AutoModelForSequenceClassification` is a generic model class
        that will be instantiated as one of the sequence classification model classes of the library
        when created with the `AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)`
        class method.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForSequenceClassification is designed to be instantiated "
            "using the `AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForSequenceClassification.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the base model classes of the library
        from a configuration.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForSequenceClassification` (DistilBERT model)
                - isInstance of `albert` configuration class: :class:`~transformers.AlbertForSequenceClassification` (ALBERT model)
                - isInstance of `camembert` configuration class: :class:`~transformers.CamembertForSequenceClassification` (CamemBERT model)
                - isInstance of `xlm-roberta` configuration class: :class:`~transformers.XLMRobertaForSequenceClassification` (XLM-RoBERTa model)
                - isInstance of `roberta` configuration class: :class:`~transformers.RobertaForSequenceClassification` (RoBERTa model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertForSequenceClassification` (Bert model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetForSequenceClassification` (XLNet model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMForSequenceClassification` (XLM model)
                - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertForSequenceClassification` (Flaubert model)


        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelForSequenceClassification.from_config(config)  # Instantiate the model from the configuration (weights are randomly initialized, not loaded)
        """
        for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()),
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the sequence classification model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `distilbert`: :class:`~transformers.DistilBertForSequenceClassification` (DistilBERT model)
            - contains `albert`: :class:`~transformers.AlbertForSequenceClassification` (ALBERT model)
            - contains `camembert`: :class:`~transformers.CamembertForSequenceClassification` (CamemBERT model)
            - contains `xlm-roberta`: :class:`~transformers.XLMRobertaForSequenceClassification` (XLM-RoBERTa model)
            - contains `roberta`: :class:`~transformers.RobertaForSequenceClassification` (RoBERTa model)
            - contains `bert`: :class:`~transformers.BertForSequenceClassification` (Bert model)
            - contains `xlnet`: :class:`~transformers.XLNetForSequenceClassification` (XLNet model)
            - contains `flaubert`: :class:`~transformers.FlaubertForSequenceClassification` (Flaubert model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`

        Args:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete an incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Examples::

            model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForSequenceClassification.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/bert_model/')`
            model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelForSequenceClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

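        # MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING is an OrderedDict of
        # (configuration class, model class) pairs; the first isinstance match
        # wins, so configuration classes that subclass others (e.g. CamembertConfig,
        # which derives from RobertaConfig) are listed before their base classes.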
        for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()),
            )
        )


class AutoModelForQuestionAnswering:
    r"""
        :class:`~transformers.AutoModelForQuestionAnswering` is a generic model class
        that will be instantiated as one of the question answering model classes of the library
        when created with the `AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)`
        class method.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForQuestionAnswering is designed to be instantiated "
            "using the `AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForQuestionAnswering.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the base model classes of the library
        from a configuration.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForQuestionAnswering` (DistilBERT model)
                - isInstance of `albert` configuration class: :class:`~transformers.AlbertForQuestionAnswering` (ALBERT model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertForQuestionAnswering` (Bert model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetForQuestionAnswering` (XLNet model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMForQuestionAnswering` (XLM model)
                - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertForQuestionAnswering` (Flaubert model)

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelForQuestionAnswering.from_config(config)
        """
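        # Build the model architecture from the configuration alone; the returned
        # model has freshly initialized weights rather than pretrained ones.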
        for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()),
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the question answering model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `distilbert`: :class:`~transformers.DistilBertForQuestionAnswering` (DistilBERT model)
            - contains `albert`: :class:`~transformers.AlbertForQuestionAnswering` (ALBERT model)
            - contains `bert`: :class:`~transformers.BertForQuestionAnswering` (Bert model)
            - contains `xlnet`: :class:`~transformers.XLNetForQuestionAnswering` (XLNet model)
            - contains `xlm`: :class:`~transformers.XLMForQuestionAnswering` (XLM model)
            - contains `flaubert`: :class:`~transformers.FlaubertForQuestionAnswering` (Flaubert model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`

        Args:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Examples::

            model = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForQuestionAnswering.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/bert_model/')`
            model = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelForQuestionAnswering.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()),
            )
        )


class AutoModelForTokenClassification:
    r"""
        :class:`~transformers.AutoModelForTokenClassification` is a generic model class
        that will be instantiated as one of the token classification model classes of the library
        when created with the `AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path)`
        class method.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForTokenClassification is designed to be instantiated "
            "using the `AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForTokenClassification.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the base model classes of the library
        from a configuration.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForTokenClassification` (DistilBERT model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMForTokenClassification` (XLM model)
                - isInstance of `xlm-roberta` configuration class: :class:`~transformers.XLMRobertaForTokenClassification` (XLM-RoBERTa model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertForTokenClassification` (Bert model)
                - isInstance of `albert` configuration class: :class:`~transformers.AlbertForTokenClassification` (ALBERT model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetForTokenClassification` (XLNet model)
                - isInstance of `camembert` configuration class: :class:`~transformers.CamembertForTokenClassification` (CamemBERT model)
                - isInstance of `roberta` configuration class: :class:`~transformers.RobertaForTokenClassification` (RoBERTa model)
                - isInstance of `electra` configuration class: :class:`~transformers.ElectraForTokenClassification` (Electra model)

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelForTokenClassification.from_config(config)
        """
        for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys()),
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the question answering model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `distilbert`: :class:`~transformers.DistilBertForTokenClassification` (DistilBERT model)
            - contains `xlm`: :class:`~transformers.XLMForTokenClassification` (XLM model)
            - contains `xlm-roberta`: :class:`~transformers.XLMRobertaForTokenClassification` (XLM-RoBERTa model)
            - contains `camembert`: :class:`~transformers.CamembertForTokenClassification` (Camembert model)
            - contains `bert`: :class:`~transformers.BertForTokenClassification` (Bert model)
            - contains `xlnet`: :class:`~transformers.XLNetForTokenClassification` (XLNet model)
            - contains `roberta`: :class:`~transformers.RobertaForTokenClassification` (Roberta model)
            - contains `electra`: :class:`~transformers.ElectraForTokenClassification` (Electra model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`

        Args:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Examples::

            model = AutoModelForTokenClassification.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForTokenClassification.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/bert_model/')`
            model = AutoModelForTokenClassification.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelForTokenClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys()),
            )
        )


class AutoModelForMultipleChoice:
    r"""
        :class:`~transformers.AutoModelForMultipleChoice` is a generic model class
        that will be instantiated as one of the multiple choice model classes of the library
        when created with the `AutoModelForMultipleChoice.from_pretrained(pretrained_model_name_or_path)`
        class method.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForMultipleChoice is designed to be instantiated "
            "using the `AutoModelForMultipleChoice.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForMultipleChoice.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
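        r""" Instantiates one of the multiple choice model classes of the library
        from a configuration.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class, for example:

                - isInstance of `bert` configuration class: :class:`~transformers.BertForMultipleChoice` (Bert model)
                - isInstance of `camembert` configuration class: :class:`~transformers.CamembertForMultipleChoice` (CamemBERT model)

                The full set of supported configurations is given by the keys of
                ``MODEL_FOR_MULTIPLE_CHOICE_MAPPING``.

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelForMultipleChoice.from_config(config)
        """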
        for config_class, model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys()),
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
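        r""" Instantiates one of the multiple choice model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`

        See :func:`~transformers.AutoModelForSequenceClassification.from_pretrained` for a description of
        `pretrained_model_name_or_path`, `model_args` and the remaining keyword arguments.

        Examples::

            model = AutoModelForMultipleChoice.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForMultipleChoice.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/bert_model/')`
        """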
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        for config_class, model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys()),
            )
        )