# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class. """


import logging
from collections import OrderedDict

from .configuration_auto import (
    AlbertConfig,
    AutoConfig,
    BartConfig,
    BertConfig,
    CamembertConfig,
    CTRLConfig,
    DistilBertConfig,
    ElectraConfig,
    EncoderDecoderConfig,
    FlaubertConfig,
    GPT2Config,
    LongformerConfig,
    OpenAIGPTConfig,
    ReformerConfig,
    RobertaConfig,
    T5Config,
    TransfoXLConfig,
    XLMConfig,
    XLMRobertaConfig,
    XLNetConfig,
)
from .configuration_marian import MarianConfig
from .configuration_utils import PretrainedConfig
from .modeling_albert import (
    ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
    AlbertForMaskedLM,
    AlbertForPreTraining,
    AlbertForQuestionAnswering,
    AlbertForSequenceClassification,
    AlbertForTokenClassification,
    AlbertModel,
)
from .modeling_bart import (
    BART_PRETRAINED_MODEL_ARCHIVE_MAP,
    BartForConditionalGeneration,
    BartForSequenceClassification,
    BartModel,
)
from .modeling_bert import (
    BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
    BertForMaskedLM,
    BertForMultipleChoice,
    BertForPreTraining,
    BertForQuestionAnswering,
    BertForSequenceClassification,
    BertForTokenClassification,
    BertModel,
)
from .modeling_camembert import (
    CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
    CamembertForMaskedLM,
    CamembertForMultipleChoice,
    CamembertForSequenceClassification,
    CamembertForTokenClassification,
    CamembertModel,
)
from .modeling_ctrl import CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRLLMHeadModel, CTRLModel
from .modeling_distilbert import (
    DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
    DistilBertForMaskedLM,
    DistilBertForQuestionAnswering,
    DistilBertForSequenceClassification,
    DistilBertForTokenClassification,
    DistilBertModel,
)
from .modeling_electra import (
    ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
    ElectraForMaskedLM,
    ElectraForPreTraining,
    ElectraForSequenceClassification,
    ElectraForTokenClassification,
    ElectraModel,
)
from .modeling_encoder_decoder import EncoderDecoderModel
from .modeling_flaubert import (
    FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
    FlaubertForQuestionAnsweringSimple,
    FlaubertForSequenceClassification,
    FlaubertModel,
    FlaubertWithLMHeadModel,
)
from .modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2Model
from .modeling_longformer import (
    LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
    LongformerForMaskedLM,
    LongformerForQuestionAnswering,
    LongformerForSequenceClassification,
    LongformerModel,
)
from .modeling_marian import MarianMTModel
from .modeling_openai import OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OpenAIGPTModel
from .modeling_reformer import ReformerModel, ReformerModelWithLMHead
from .modeling_roberta import (
    ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
    RobertaForMaskedLM,
    RobertaForMultipleChoice,
    RobertaForQuestionAnswering,
    RobertaForSequenceClassification,
    RobertaForTokenClassification,
    RobertaModel,
)
from .modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5ForConditionalGeneration, T5Model
from .modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TransfoXLLMHeadModel, TransfoXLModel
from .modeling_xlm import (
    XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
    XLMForQuestionAnsweringSimple,
    XLMForSequenceClassification,
    XLMForTokenClassification,
    XLMModel,
    XLMWithLMHeadModel,
)
from .modeling_xlm_roberta import (
    XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
    XLMRobertaForMaskedLM,
    XLMRobertaForMultipleChoice,
    XLMRobertaForSequenceClassification,
    XLMRobertaForTokenClassification,
    XLMRobertaModel,
)
from .modeling_xlnet import (
    XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
    XLNetForMultipleChoice,
    XLNetForQuestionAnsweringSimple,
    XLNetForSequenceClassification,
    XLNetForTokenClassification,
    XLNetLMHeadModel,
    XLNetModel,
)


logger = logging.getLogger(__name__)


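# Flat {shortcut name -> pretrained weights URL} lookup built by merging the
# per-model archive maps imported above. Illustrative use (assuming the usual
# BERT shortcut name is present in BERT_PRETRAINED_MODEL_ARCHIVE_MAP):
#
#     url = ALL_PRETRAINED_MODEL_ARCHIVE_MAP["bert-base-uncased"]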
ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
    (key, value)
    for pretrained_map in [
        BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        BART_PRETRAINED_MODEL_ARCHIVE_MAP,
        OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
        TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
        GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
        CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
        ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        T5_PRETRAINED_MODEL_ARCHIVE_MAP,
        FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
        ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
        LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
    ]
    for key, value in pretrained_map.items()
)

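# Note: the (config class, model class) pairs below are matched in order with
# isinstance(), so a configuration class that subclasses another one (e.g.
# CamembertConfig, LongformerConfig and XLMRobertaConfig all derive from
# RobertaConfig, which itself derives from BertConfig) must be listed before
# its parent, or the parent's model class would be returned instead.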
MODEL_MAPPING = OrderedDict(
    [
        (T5Config, T5Model),
        (DistilBertConfig, DistilBertModel),
        (AlbertConfig, AlbertModel),
        (CamembertConfig, CamembertModel),
        (XLMRobertaConfig, XLMRobertaModel),
        (BartConfig, BartModel),
        (LongformerConfig, LongformerModel),
        (RobertaConfig, RobertaModel),
        (BertConfig, BertModel),
        (OpenAIGPTConfig, OpenAIGPTModel),
        (GPT2Config, GPT2Model),
        (TransfoXLConfig, TransfoXLModel),
        (XLNetConfig, XLNetModel),
        (FlaubertConfig, FlaubertModel),
        (XLMConfig, XLMModel),
        (CTRLConfig, CTRLModel),
        (ElectraConfig, ElectraModel),
        (ReformerConfig, ReformerModel),
    ]
)

MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
    [
        (T5Config, T5ForConditionalGeneration),
        (DistilBertConfig, DistilBertForMaskedLM),
        (AlbertConfig, AlbertForPreTraining),
        (CamembertConfig, CamembertForMaskedLM),
        (XLMRobertaConfig, XLMRobertaForMaskedLM),
        (BartConfig, BartForConditionalGeneration),
        (LongformerConfig, LongformerForMaskedLM),
        (RobertaConfig, RobertaForMaskedLM),
        (BertConfig, BertForPreTraining),
        (OpenAIGPTConfig, OpenAIGPTLMHeadModel),
        (GPT2Config, GPT2LMHeadModel),
        (TransfoXLConfig, TransfoXLLMHeadModel),
        (XLNetConfig, XLNetLMHeadModel),
        (FlaubertConfig, FlaubertWithLMHeadModel),
        (XLMConfig, XLMWithLMHeadModel),
        (CTRLConfig, CTRLLMHeadModel),
        (ElectraConfig, ElectraForPreTraining),
    ]
)

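# Note: this mapping mixes masked-LM, causal-LM and sequence-to-sequence heads
# (e.g. Marian, Bart, T5 and EncoderDecoder map to conditional-generation
# classes), so the head type of the returned model depends entirely on the
# configuration class.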
MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
    [
        (T5Config, T5ForConditionalGeneration),
        (DistilBertConfig, DistilBertForMaskedLM),
        (AlbertConfig, AlbertForMaskedLM),
        (CamembertConfig, CamembertForMaskedLM),
        (XLMRobertaConfig, XLMRobertaForMaskedLM),
        (MarianConfig, MarianMTModel),
        (BartConfig, BartForConditionalGeneration),
        (LongformerConfig, LongformerForMaskedLM),
        (RobertaConfig, RobertaForMaskedLM),
        (BertConfig, BertForMaskedLM),
        (OpenAIGPTConfig, OpenAIGPTLMHeadModel),
        (GPT2Config, GPT2LMHeadModel),
        (TransfoXLConfig, TransfoXLLMHeadModel),
        (XLNetConfig, XLNetLMHeadModel),
        (FlaubertConfig, FlaubertWithLMHeadModel),
        (XLMConfig, XLMWithLMHeadModel),
        (CTRLConfig, CTRLLMHeadModel),
        (ElectraConfig, ElectraForMaskedLM),
        (EncoderDecoderConfig, EncoderDecoderModel),
        (ReformerConfig, ReformerModelWithLMHead),
    ]
)

MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
    [
        (DistilBertConfig, DistilBertForSequenceClassification),
        (AlbertConfig, AlbertForSequenceClassification),
        (CamembertConfig, CamembertForSequenceClassification),
        (XLMRobertaConfig, XLMRobertaForSequenceClassification),
        (BartConfig, BartForSequenceClassification),
        (LongformerConfig, LongformerForSequenceClassification),
        (RobertaConfig, RobertaForSequenceClassification),
        (BertConfig, BertForSequenceClassification),
        (XLNetConfig, XLNetForSequenceClassification),
        (FlaubertConfig, FlaubertForSequenceClassification),
        (XLMConfig, XLMForSequenceClassification),
        (ElectraConfig, ElectraForSequenceClassification),
    ]
)

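# Note: XLNet, Flaubert and XLM are mapped to their *ForQuestionAnsweringSimple
# variants here, i.e. the plain span-prediction heads rather than the richer
# question answering heads those models also provide.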
MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
    [
        (DistilBertConfig, DistilBertForQuestionAnswering),
        (AlbertConfig, AlbertForQuestionAnswering),
        (LongformerConfig, LongformerForQuestionAnswering),
        (RobertaConfig, RobertaForQuestionAnswering),
        (BertConfig, BertForQuestionAnswering),
        (XLNetConfig, XLNetForQuestionAnsweringSimple),
        (FlaubertConfig, FlaubertForQuestionAnsweringSimple),
        (XLMConfig, XLMForQuestionAnsweringSimple),
    ]
)

MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
    [
        (DistilBertConfig, DistilBertForTokenClassification),
        (CamembertConfig, CamembertForTokenClassification),
        (XLMConfig, XLMForTokenClassification),
        (XLMRobertaConfig, XLMRobertaForTokenClassification),
        (RobertaConfig, RobertaForTokenClassification),
        (BertConfig, BertForTokenClassification),
        (XLNetConfig, XLNetForTokenClassification),
        (AlbertConfig, AlbertForTokenClassification),
        (ElectraConfig, ElectraForTokenClassification),
    ]
)


MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
    [
        (CamembertConfig, CamembertForMultipleChoice),
        (XLMRobertaConfig, XLMRobertaForMultipleChoice),
        (RobertaConfig, RobertaForMultipleChoice),
        (BertConfig, BertForMultipleChoice),
        (XLNetConfig, XLNetForMultipleChoice),
    ]
)


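# Each Auto* class below resolves a configuration against one of the mappings
# above. A minimal sketch of the dispatch (the `roberta-base` shortcut is used
# purely as an illustration):
#
#     config = AutoConfig.from_pretrained("roberta-base")  # -> RobertaConfig
#     model = AutoModel.from_config(config)                # -> RobertaModel, untrained
#     model = AutoModel.from_pretrained("roberta-base")    # -> RobertaModel with weights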
class AutoModel:
    r"""
        :class:`~transformers.AutoModel` is a generic model class
        that will be instantiated as one of the base model classes of the library
        when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
        or the `AutoModel.from_config(config)` class methods.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModel is designed to be instantiated "
            "using the `AutoModel.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModel.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the base model classes of the library
        from a configuration.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertModel` (DistilBERT model)
                - isInstance of `longformer` configuration class: :class:`~transformers.LongformerModel` (Longformer model)
                - isInstance of `roberta` configuration class: :class:`~transformers.RobertaModel` (RoBERTa model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertModel` (Bert model)
                - isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTModel` (OpenAI GPT model)
                - isInstance of `gpt2` configuration class: :class:`~transformers.GPT2Model` (OpenAI GPT-2 model)
                - isInstance of `ctrl` configuration class: :class:`~transformers.CTRLModel` (Salesforce CTRL model)
                - isInstance of `transfo-xl` configuration class: :class:`~transformers.TransfoXLModel` (Transformer-XL model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetModel` (XLNet model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMModel` (XLM model)
                - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertModel` (Flaubert model)
                - isInstance of `electra` configuration class: :class:`~transformers.ElectraModel` (Electra model)

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModel.from_config(config)  # Model is built from the config alone; no pretrained weights are loaded.
        """
        for config_class, model_class in MODEL_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_MAPPING.keys())
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the base model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The base model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: :class:`~transformers.T5Model` (T5 model)
            - contains `distilbert`: :class:`~transformers.DistilBertModel` (DistilBERT model)
            - contains `albert`: :class:`~transformers.AlbertModel` (ALBERT model)
            - contains `camembert`: :class:`~transformers.CamembertModel` (CamemBERT model)
            - contains `xlm-roberta`: :class:`~transformers.XLMRobertaModel` (XLM-RoBERTa model)
            - contains `longformer` :class:`~transformers.LongformerModel` (Longformer model)
            - contains `roberta`: :class:`~transformers.RobertaModel` (RoBERTa model)
            - contains `bert`: :class:`~transformers.BertModel` (Bert model)
            - contains `openai-gpt`: :class:`~transformers.OpenAIGPTModel` (OpenAI GPT model)
            - contains `gpt2`: :class:`~transformers.GPT2Model` (OpenAI GPT-2 model)
            - contains `transfo-xl`: :class:`~transformers.TransfoXLModel` (Transformer-XL model)
            - contains `xlnet`: :class:`~transformers.XLNetModel` (XLNet model)
            - contains `xlm`: :class:`~transformers.XLMModel` (XLM model)
            - contains `ctrl`: :class:`~transformers.CTRLModel` (Salesforce CTRL model)
            - contains `flaubert`: :class:`~transformers.FlaubertModel` (Flaubert model)
            - contains `electra`: :class:`~transformers.ElectraModel` (Electra model)

            The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
            To train the model, you should first set it back in training mode with `model.train()`

        Args:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Examples::

            model = AutoModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModel.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = AutoModel.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        for config_class, model_class in MODEL_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_MAPPING.keys())
            )
        )


class AutoModelForPreTraining:
    r"""
        :class:`~transformers.AutoModelForPreTraining` is a generic model class
        that will be instantiated as one of the model classes of the library (with the architecture used for pretraining this model) when created with the `AutoModelForPreTraining.from_pretrained(pretrained_model_name_or_path)`
        class method.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForPreTraining is designed to be instantiated "
            "using the `AutoModelForPreTraining.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForPreTraining.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the base model classes of the library
        from a configuration.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
                - isInstance of `longformer` configuration class: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
                - isInstance of `roberta` configuration class: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertForPreTraining` (Bert model)
                - isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
                - isInstance of `gpt2` configuration class: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
                - isInstance of `ctrl` configuration class: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
                - isInstance of `transfo-xl` configuration class: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
                - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
                - isInstance of `electra` configuration class: :class:`~transformers.ElectraForPreTraining` (Electra model)

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelForPreTraining.from_config(config)  # Model is built from the config alone; no pretrained weights are loaded.
        """
        for config_class, model_class in MODEL_FOR_PRETRAINING_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_PRETRAINING_MAPPING.keys())
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the model classes of the library -with the architecture used for pretraining this model– from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: :class:`~transformers.T5ForConditionalGeneration` (T5 model)
            - contains `distilbert`: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
            - contains `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model)
            - contains `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
            - contains `xlm-roberta`: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model)
            - contains `longformer`: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
            - contains `roberta`: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
            - contains `bert`: :class:`~transformers.BertForPreTraining` (Bert model)
            - contains `openai-gpt`: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
            - contains `gpt2`: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
            - contains `transfo-xl`: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
            - contains `xlnet`: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
            - contains `xlm`: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
            - contains `ctrl`: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
            - contains `flaubert`: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
            - contains `electra`: :class:`~transformers.ElectraForPreTraining` (Electra model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`

        Args:
            pretrained_model_name_or_path:
                Either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method
            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.
            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.
            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.
            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Examples::

            model = AutoModelForPreTraining.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForPreTraining.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = AutoModelForPreTraining.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelForPreTraining.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        for config_class, model_class in MODEL_FOR_PRETRAINING_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_PRETRAINING_MAPPING.keys())
            )
        )


class AutoModelWithLMHead:
    r"""
        :class:`~transformers.AutoModelWithLMHead` is a generic model class
        that will be instantiated as one of the language modeling model classes of the library
        when created with the `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)`
        class method.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelWithLMHead is designed to be instantiated "
            "using the `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelWithLMHead.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the base model classes of the library
        from a configuration.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
                - isInstance of `longformer` configuration class: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
                - isInstance of `roberta` configuration class: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertForMaskedLM` (Bert model)
                - isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
                - isInstance of `gpt2` configuration class: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
                - isInstance of `ctrl` configuration class: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
                - isInstance of `transfo-xl` configuration class: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
                - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
                - isInstance of `electra` configuration class: :class:`~transformers.ElectraForMaskedLM` (Electra model)

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelWithLMHead.from_config(config)  # Model is built from the config alone; no pretrained weights are loaded.
        """
        for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_WITH_LM_HEAD_MAPPING.keys())
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the language modeling model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: :class:`~transformers.T5ForConditionalGeneration` (T5 model)
            - contains `distilbert`: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
            - contains `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model)
            - contains `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
            - contains `xlm-roberta`: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model)
            - contains `longformer`: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
            - contains `roberta`: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
            - contains `bert`: :class:`~transformers.BertForMaskedLM` (Bert model)
            - contains `openai-gpt`: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
            - contains `gpt2`: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
            - contains `transfo-xl`: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
            - contains `xlnet`: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
            - contains `xlm`: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
            - contains `ctrl`: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
            - contains `flaubert`: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
            - contains `electra`: :class:`~transformers.ElectraForMaskedLM` (Electra model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`

        Args:
            pretrained_model_name_or_path:
                Either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method
            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.
            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.
            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.
            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Examples::

            model = AutoModelWithLMHead.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelWithLMHead.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = AutoModelWithLMHead.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_WITH_LM_HEAD_MAPPING.keys())
            )
        )


class AutoModelForSequenceClassification:
    r"""
        :class:`~transformers.AutoModelForSequenceClassification` is a generic model class
        that will be instantiated as one of the sequence classification model classes of the library
        when created with the `AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)`
        class method.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForSequenceClassification is designed to be instantiated "
            "using the `AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForSequenceClassification.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the base model classes of the library
        from a configuration.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForSequenceClassification` (DistilBERT model)
                - isInstance of `albert` configuration class: :class:`~transformers.AlbertForSequenceClassification` (ALBERT model)
                - isInstance of `camembert` configuration class: :class:`~transformers.CamembertForSequenceClassification` (CamemBERT model)
                - isInstance of `xlm roberta` configuration class: :class:`~transformers.XLMRobertaForSequenceClassification` (XLM-RoBERTa model)
                - isInstance of `roberta` configuration class: :class:`~transformers.RobertaForSequenceClassification` (RoBERTa model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertForSequenceClassification` (Bert model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetForSequenceClassification` (XLNet model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMForSequenceClassification` (XLM model)
                - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertForSequenceClassification` (Flaubert model)


        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelForSequenceClassification.from_config(config)  # Model is built from the config alone; no pretrained weights are loaded.
        """
        for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()),
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the sequence classification model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `distilbert`: :class:`~transformers.DistilBertForSequenceClassification` (DistilBERT model)
            - contains `albert`: :class:`~transformers.AlbertForSequenceClassification` (ALBERT model)
            - contains `camembert`: :class:`~transformers.CamembertForSequenceClassification` (CamemBERT model)
            - contains `xlm-roberta`: :class:`~transformers.XLMRobertaForSequenceClassification` (XLM-RoBERTa model)
            - contains `roberta`: :class:`~transformers.RobertaForSequenceClassification` (RoBERTa model)
            - contains `bert`: :class:`~transformers.BertForSequenceClassification` (Bert model)
            - contains `xlnet`: :class:`~transformers.XLNetForSequenceClassification` (XLNet model)
            - contains `flaubert`: :class:`~transformers.FlaubertForSequenceClassification` (Flaubert model)

        The model is set in evaluation mode by default using `model.eval()` (dropout modules are deactivated).
        To train the model, you should first set it back in training mode with `model.train()`.

        Args:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                An optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete an incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Examples::

            model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForSequenceClassification.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/bert_model/')`
            config = AutoConfig.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', config=config)
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelForSequenceClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
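            # The same pattern sizes the classification head, e.g. for 3 labels (value is illustrative):
            config = AutoConfig.from_pretrained('bert-base-uncased', num_labels=3)
            model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', config=config)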

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()),
            )
        )


class AutoModelForQuestionAnswering:
    r"""
        :class:`~transformers.AutoModelForQuestionAnswering` is a generic model class
        that will be instantiated as one of the question answering model classes of the library
        when created with the `AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)`
        class method.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForQuestionAnswering is designed to be instantiated "
            "using the `AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForQuestionAnswering.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the base model classes of the library
        from a configuration.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForQuestionAnswering` (DistilBERT model)
                - isInstance of `albert` configuration class: :class:`~transformers.AlbertForQuestionAnswering` (ALBERT model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertForQuestionAnswering` (Bert model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetForQuestionAnswering` (XLNet model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMForQuestionAnswering` (XLM model)
                - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertForQuestionAnswering` (Flaubert model)

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelForQuestionAnswering.from_config(config)  # Instantiate a model (with random weights) from the configuration.
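            # Configurations saved with `save_pretrained` load the same way (path is illustrative):
            config = AutoConfig.from_pretrained('./test/saved_model/')
            model = AutoModelForQuestionAnswering.from_config(config)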
        """
        for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()),
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the question answering model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `distilbert`: :class:`~transformers.DistilBertForQuestionAnswering` (DistilBERT model)
            - contains `albert`: :class:`~transformers.AlbertForQuestionAnswering` (ALBERT model)
            - contains `bert`: :class:`~transformers.BertForQuestionAnswering` (Bert model)
            - contains `xlnet`: :class:`~transformers.XLNetForQuestionAnswering` (XLNet model)
            - contains `xlm`: :class:`~transformers.XLMForQuestionAnswering` (XLM model)
            - contains `flaubert`: :class:`~transformers.FlaubertForQuestionAnswering` (Flaubert model)

        The model is set in evaluation mode by default using `model.eval()` (dropout modules are deactivated).
        To train the model, you should first set it back in training mode with `model.train()`.

        Args:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                An optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Examples::

            model = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForQuestionAnswering.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/bert_model/')`
            config = AutoConfig.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            model = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased', config=config)
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelForQuestionAnswering.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
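            # With `output_loading_info=True` (documented above), a dict of missing/unexpected keys is also returned:
            model, loading_info = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased', output_loading_info=True)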

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()),
            )
        )


class AutoModelForTokenClassification:
    r"""
        :class:`~transformers.AutoModelForTokenClassification` is a generic model class
        that will be instantiated as one of the token classification model classes of the library
        when created with the `AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path)`
        class method.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForTokenClassification is designed to be instantiated "
            "using the `AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForTokenClassification.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the base model classes of the library
        from a configuration.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForTokenClassification` (DistilBERT model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMForTokenClassification` (XLM model)
                - isInstance of `xlm roberta` configuration class: :class:`~transformers.XLMRobertaForTokenClassification` (XLM-RoBERTa model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertForTokenClassification` (Bert model)
                - isInstance of `albert` configuration class: :class:`~transformers.AlbertForTokenClassification` (ALBERT model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetForTokenClassification` (XLNet model)
                - isInstance of `camembert` configuration class: :class:`~transformers.CamembertForTokenClassification` (CamemBERT model)
                - isInstance of `roberta` configuration class: :class:`~transformers.RobertaForTokenClassification` (RoBERTa model)
                - isInstance of `electra` configuration class: :class:`~transformers.ElectraForTokenClassification` (Electra model)

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelForTokenClassification.from_config(config)  # Instantiate a model (with random weights) from the configuration.
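            # Any configuration class in the mapping works, e.g. ELECTRA (identifier is illustrative):
            config = AutoConfig.from_pretrained('google/electra-small-discriminator')
            model = AutoModelForTokenClassification.from_config(config)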
        """
        for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys()),
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the question answering model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `distilbert`: :class:`~transformers.DistilBertForTokenClassification` (DistilBERT model)
            - contains `xlm`: :class:`~transformers.XLMForTokenClassification` (XLM model)
            - contains `xlm-roberta`: :class:`~transformers.XLMRobertaForTokenClassification` (XLM-RoBERTa model)
            - contains `camembert`: :class:`~transformers.CamembertForTokenClassification` (CamemBERT model)
            - contains `bert`: :class:`~transformers.BertForTokenClassification` (Bert model)
            - contains `xlnet`: :class:`~transformers.XLNetForTokenClassification` (XLNet model)
            - contains `roberta`: :class:`~transformers.RobertaForTokenClassification` (RoBERTa model)
            - contains `electra`: :class:`~transformers.ElectraForTokenClassification` (Electra model)

        The model is set in evaluation mode by default using `model.eval()` (dropout modules are deactivated).
        To train the model, you should first set it back in training mode with `model.train()`.

        Args:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                An optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Examples::

            model = AutoModelForTokenClassification.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForTokenClassification.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/bert_model/')`
            config = AutoConfig.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            model = AutoModelForTokenClassification.from_pretrained('bert-base-uncased', config=config)
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelForTokenClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
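            # A custom cache directory can be supplied as documented above (path is illustrative):
            model = AutoModelForTokenClassification.from_pretrained('bert-base-uncased', cache_dir='./my_cache/')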

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys()),
            )
        )


class AutoModelForMultipleChoice:
    r"""
        :class:`~transformers.AutoModelForMultipleChoice` is a generic model class
        that will be instantiated as one of the multiple choice model classes of the library
        when created with the `AutoModelForMultipleChoice.from_pretrained(pretrained_model_name_or_path)`
        class method.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForMultipleChoice is designed to be instantiated "
            "using the `AutoModelForMultipleChoice.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForMultipleChoice.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
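        r""" Instantiates one of the multiple choice model classes of the library
        from a configuration.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class.

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache (shortcut name is illustrative).
            model = AutoModelForMultipleChoice.from_config(config)  # Instantiate a model (with random weights) from the configuration.
        """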
        for config_class, model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys()),
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
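        r""" Instantiates one of the multiple choice model classes of the library
        from a pre-trained model configuration, following the same selection and
        loading rules as the other `AutoModelFor...` classes above.

        Examples::

            model = AutoModelForMultipleChoice.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache (shortcut name is illustrative).
        """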
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        for config_class, model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys()),
            )
        )