# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class. """


import logging
from collections import OrderedDict

from .configuration_auto import (
    AlbertConfig,
    AutoConfig,
    BartConfig,
    BertConfig,
    CamembertConfig,
    CTRLConfig,
    DistilBertConfig,
    FlaubertConfig,
    GPT2Config,
    OpenAIGPTConfig,
    RobertaConfig,
    T5Config,
    TransfoXLConfig,
    XLMConfig,
    XLMRobertaConfig,
    XLNetConfig,
)
from .configuration_utils import PretrainedConfig
from .modeling_albert import (
    ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
    AlbertForMaskedLM,
    AlbertForQuestionAnswering,
    AlbertForSequenceClassification,
    AlbertForTokenClassification,
    AlbertModel,
)
from .modeling_bart import BART_PRETRAINED_MODEL_ARCHIVE_MAP, BartForMaskedLM, BartForSequenceClassification, BartModel
from .modeling_bert import (
    BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
    BertForMaskedLM,
    BertForPreTraining,
    BertForQuestionAnswering,
    BertForSequenceClassification,
    BertForTokenClassification,
    BertModel,
)
from .modeling_camembert import (
    CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
    CamembertForMaskedLM,
    CamembertForSequenceClassification,
    CamembertForTokenClassification,
    CamembertModel,
)
from .modeling_ctrl import CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRLLMHeadModel, CTRLModel
from .modeling_distilbert import (
    DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
    DistilBertForMaskedLM,
    DistilBertForQuestionAnswering,
    DistilBertForSequenceClassification,
    DistilBertForTokenClassification,
    DistilBertModel,
)
from .modeling_flaubert import (
    FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
    FlaubertForQuestionAnswering,
    FlaubertForSequenceClassification,
    FlaubertModel,
    FlaubertWithLMHeadModel,
)
from .modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2Model
from .modeling_openai import OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OpenAIGPTModel
from .modeling_roberta import (
    ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
    RobertaForMaskedLM,
    RobertaForQuestionAnswering,
    RobertaForSequenceClassification,
    RobertaForTokenClassification,
    RobertaModel,
)
from .modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5Model, T5WithLMHeadModel
from .modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TransfoXLLMHeadModel, TransfoXLModel
from .modeling_xlm import (
    XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
    XLMForQuestionAnswering,
    XLMForSequenceClassification,
    XLMModel,
    XLMWithLMHeadModel,
)
from .modeling_xlm_roberta import (
    XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
    XLMRobertaForMaskedLM,
    XLMRobertaForSequenceClassification,
    XLMRobertaForTokenClassification,
    XLMRobertaModel,
)
from .modeling_xlnet import (
    XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
    XLNetForQuestionAnswering,
    XLNetForSequenceClassification,
    XLNetForTokenClassification,
    XLNetLMHeadModel,
    XLNetModel,
)


logger = logging.getLogger(__name__)


ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
    (key, value)
    for pretrained_map in [
        BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        BART_PRETRAINED_MODEL_ARCHIVE_MAP,
        OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
        TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
        GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
        CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
        ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        T5_PRETRAINED_MODEL_ARCHIVE_MAP,
        FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
    ]
    for key, value in pretrained_map.items()
)
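
# The merged archive map lets callers resolve any supported shortcut name to its
# weights URL without knowing which architecture it belongs to. A minimal lookup
# sketch (the URL shown is illustrative only):
#
#     url = ALL_PRETRAINED_MODEL_ARCHIVE_MAP["bert-base-uncased"]
#     # e.g. "https://.../bert-base-uncased-pytorch_model.bin"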

MODEL_MAPPING = OrderedDict(
    [
        (T5Config, T5Model),
        (DistilBertConfig, DistilBertModel),
        (AlbertConfig, AlbertModel),
        (CamembertConfig, CamembertModel),
        (XLMRobertaConfig, XLMRobertaModel),
        (BartConfig, BartModel),
        (RobertaConfig, RobertaModel),
        (BertConfig, BertModel),
        (OpenAIGPTConfig, OpenAIGPTModel),
        (GPT2Config, GPT2Model),
        (TransfoXLConfig, TransfoXLModel),
        (XLNetConfig, XLNetModel),
        (FlaubertConfig, FlaubertModel),
        (XLMConfig, XLMModel),
        (CTRLConfig, CTRLModel),
    ]
)
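
# Note on ordering: the Auto classes below walk these OrderedDicts and return
# the first ``isinstance`` match, so more specific configuration classes must be
# listed before their base classes (e.g. CamembertConfig subclasses
# RobertaConfig, which subclasses BertConfig) or they would never be reached.
# A minimal dispatch sketch for a CamembertConfig instance:
#
#     config = CamembertConfig()
#     model_cls = next(m for c, m in MODEL_MAPPING.items() if isinstance(config, c))
#     assert model_cls is CamembertModel  # not RobertaModel or BertModel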

MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
    [
        (T5Config, T5WithLMHeadModel),
        (DistilBertConfig, DistilBertForMaskedLM),
        (AlbertConfig, AlbertForMaskedLM),
        (CamembertConfig, CamembertForMaskedLM),
        (XLMRobertaConfig, XLMRobertaForMaskedLM),
        (BartConfig, BartForMaskedLM),
        (RobertaConfig, RobertaForMaskedLM),
        (BertConfig, BertForPreTraining),
        (OpenAIGPTConfig, OpenAIGPTLMHeadModel),
        (GPT2Config, GPT2LMHeadModel),
        (TransfoXLConfig, TransfoXLLMHeadModel),
        (XLNetConfig, XLNetLMHeadModel),
        (FlaubertConfig, FlaubertWithLMHeadModel),
        (XLMConfig, XLMWithLMHeadModel),
        (CTRLConfig, CTRLLMHeadModel),
    ]
)

MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
    [
        (T5Config, T5WithLMHeadModel),
        (DistilBertConfig, DistilBertForMaskedLM),
        (AlbertConfig, AlbertForMaskedLM),
        (CamembertConfig, CamembertForMaskedLM),
        (XLMRobertaConfig, XLMRobertaForMaskedLM),
        (BartConfig, BartForMaskedLM),
        (RobertaConfig, RobertaForMaskedLM),
        (BertConfig, BertForMaskedLM),
        (OpenAIGPTConfig, OpenAIGPTLMHeadModel),
        (GPT2Config, GPT2LMHeadModel),
        (TransfoXLConfig, TransfoXLLMHeadModel),
        (XLNetConfig, XLNetLMHeadModel),
        (FlaubertConfig, FlaubertWithLMHeadModel),
        (XLMConfig, XLMWithLMHeadModel),
        (CTRLConfig, CTRLLMHeadModel),
    ]
)

MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
    [
        (DistilBertConfig, DistilBertForSequenceClassification),
        (AlbertConfig, AlbertForSequenceClassification),
        (CamembertConfig, CamembertForSequenceClassification),
        (XLMRobertaConfig, XLMRobertaForSequenceClassification),
        (BartConfig, BartForSequenceClassification),
        (RobertaConfig, RobertaForSequenceClassification),
        (BertConfig, BertForSequenceClassification),
        (XLNetConfig, XLNetForSequenceClassification),
        (FlaubertConfig, FlaubertForSequenceClassification),
        (XLMConfig, XLMForSequenceClassification),
    ]
)

MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
    [
        (DistilBertConfig, DistilBertForQuestionAnswering),
        (AlbertConfig, AlbertForQuestionAnswering),
        (RobertaConfig, RobertaForQuestionAnswering),
        (BertConfig, BertForQuestionAnswering),
        (XLNetConfig, XLNetForQuestionAnswering),
        (FlaubertConfig, FlaubertForQuestionAnswering),
        (XLMConfig, XLMForQuestionAnswering),
    ]
)

MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
    [
        (DistilBertConfig, DistilBertForTokenClassification),
        (CamembertConfig, CamembertForTokenClassification),
        (XLMRobertaConfig, XLMRobertaForTokenClassification),
        (RobertaConfig, RobertaForTokenClassification),
        (BertConfig, BertForTokenClassification),
        (XLNetConfig, XLNetForTokenClassification),
        (AlbertConfig, AlbertForTokenClassification),
    ]
)
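
# Extending the dispatch amounts to registering a new (config, model) pair. A
# sketch with hypothetical classes, assuming MyConfig derives from
# PretrainedConfig and MyModel follows the PreTrainedModel API:
#
#     MODEL_MAPPING[MyConfig] = MyModel
#
# Since lookups take the first isinstance match, a pair appended at the end is
# only reached when no earlier configuration class matches.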


class AutoModel(object):
    r"""
        :class:`~transformers.AutoModel` is a generic model class
        that will be instantiated as one of the base model classes of the library
        when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
        or the `AutoModel.from_config(config)` class methods.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModel is designed to be instantiated "
            "using the `AutoModel.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModel.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the base model classes of the library
        from a configuration.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertModel` (DistilBERT model)
                - isInstance of `roberta` configuration class: :class:`~transformers.RobertaModel` (RoBERTa model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertModel` (Bert model)
                - isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTModel` (OpenAI GPT model)
                - isInstance of `gpt2` configuration class: :class:`~transformers.GPT2Model` (OpenAI GPT-2 model)
                - isInstance of `ctrl` configuration class: :class:`~transformers.CTRLModel` (Salesforce CTRL  model)
                - isInstance of `transfo-xl` configuration class: :class:`~transformers.TransfoXLModel` (Transformer-XL model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetModel` (XLNet model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMModel` (XLM model)
                - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertModel` (Flaubert model)

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModel.from_config(config)  # Instantiate the model (with freshly initialized weights) from the config
        """
        for config_class, model_class in MODEL_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_MAPPING.keys())
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the base model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The base model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: :class:`~transformers.T5Model` (T5 model)
            - contains `distilbert`: :class:`~transformers.DistilBertModel` (DistilBERT model)
            - contains `albert`: :class:`~transformers.AlbertModel` (ALBERT model)
            - contains `camembert`: :class:`~transformers.CamembertModel` (CamemBERT model)
            - contains `xlm-roberta`: :class:`~transformers.XLMRobertaModel` (XLM-RoBERTa model)
            - contains `roberta`: :class:`~transformers.RobertaModel` (RoBERTa model)
            - contains `bert`: :class:`~transformers.BertModel` (Bert model)
            - contains `openai-gpt`: :class:`~transformers.OpenAIGPTModel` (OpenAI GPT model)
            - contains `gpt2`: :class:`~transformers.GPT2Model` (OpenAI GPT-2 model)
            - contains `transfo-xl`: :class:`~transformers.TransfoXLModel` (Transformer-XL model)
            - contains `xlnet`: :class:`~transformers.XLNetModel` (XLNet model)
            - contains `xlm`: :class:`~transformers.XLMModel` (XLM model)
            - contains `ctrl`: :class:`~transformers.CTRLModel` (Salesforce CTRL  model)
            - contains `flaubert`: :class:`~transformers.FlaubertModel` (Flaubert model)

            The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
            To train the model, you should first set it back in training mode with `model.train()`

        Args:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Examples::

            model = AutoModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModel.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        for config_class, model_class in MODEL_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_MAPPING.keys())
            )
        )
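
# Dispatch note: from_pretrained resolves the class through AutoConfig, so a
# name such as 'roberta-base' maps to RobertaModel even though it also contains
# the substring 'bert'; `roberta` is checked before `bert`, per the ordering
# documented above.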


class AutoModelForPreTraining(object):
    r"""
        :class:`~transformers.AutoModelForPreTraining` is a generic model class
        that will be instantiated as one of the model classes of the library (with the architecture used for pretraining this model) when created with the `AutoModelForPreTraining.from_pretrained(pretrained_model_name_or_path)`
        class method.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForPreTraining is designed to be instantiated "
            "using the `AutoModelForPreTraining.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForPreTraining.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the pretraining model classes of the library
        from a configuration.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
                - isInstance of `roberta` configuration class: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertForPreTraining` (Bert model)
                - isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
                - isInstance of `gpt2` configuration class: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
                - isInstance of `ctrl` configuration class: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
                - isInstance of `transfo-xl` configuration class: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
                - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelForPreTraining.from_config(config)  # Instantiate the model (with freshly initialized weights) from the config
        """
        for config_class, model_class in MODEL_FOR_PRETRAINING_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_PRETRAINING_MAPPING.keys())
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the model classes of the library (with the architecture used for pretraining this model) from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: :class:`~transformers.T5WithLMHeadModel` (T5 model)
            - contains `distilbert`: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
            - contains `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model)
            - contains `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
            - contains `xlm-roberta`: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model)
            - contains `roberta`: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
            - contains `bert`: :class:`~transformers.BertForPreTraining` (Bert model)
            - contains `openai-gpt`: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
            - contains `gpt2`: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
            - contains `transfo-xl`: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
            - contains `xlnet`: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
            - contains `xlm`: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
            - contains `ctrl`: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
            - contains `flaubert`: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`

        Args:
            pretrained_model_name_or_path:
                Either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method
            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.
            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.
            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.
            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Examples::

            model = AutoModelForPreTraining.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForPreTraining.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelForPreTraining.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        for config_class, model_class in MODEL_FOR_PRETRAINING_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_PRETRAINING_MAPPING.keys())
            )
        )
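
# MODEL_FOR_PRETRAINING_MAPPING and MODEL_WITH_LM_HEAD_MAPPING differ only in
# their Bert entry: pretraining dispatches to BertForPreTraining (masked LM plus
# next-sentence-prediction heads), while the LM-head mapping uses
# BertForMaskedLM. All other architectures share the same head class.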


class AutoModelWithLMHead(object):
    r"""
        :class:`~transformers.AutoModelWithLMHead` is a generic model class
        that will be instantiated as one of the language modeling model classes of the library
        when created with the `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)`
        class method.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelWithLMHead is designed to be instantiated "
            "using the `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelWithLMHead.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the language modeling model classes of the library
        from a configuration.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
                - isInstance of `roberta` configuration class: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertForMaskedLM` (Bert model)
                - isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
                - isInstance of `gpt2` configuration class: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
                - isInstance of `ctrl` configuration class: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
                - isInstance of `transfo-xl` configuration class: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
                - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelWithLMHead.from_config(config)  # Instantiate the model (with freshly initialized weights) from the config
        """
        for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_WITH_LM_HEAD_MAPPING.keys())
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the language modeling model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: :class:`~transformers.T5WithLMHeadModel` (T5 model)
            - contains `distilbert`: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
            - contains `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model)
            - contains `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
            - contains `xlm-roberta`: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model)
            - contains `roberta`: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
            - contains `bert`: :class:`~transformers.BertForMaskedLM` (Bert model)
            - contains `openai-gpt`: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
            - contains `gpt2`: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
            - contains `transfo-xl`: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
            - contains `xlnet`: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
            - contains `xlm`: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
            - contains `ctrl`: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
            - contains `flaubert`: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`

        Args:
            pretrained_model_name_or_path:
                Either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method
            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.
            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.
            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.
            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Examples::

            model = AutoModelWithLMHead.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelWithLMHead.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_WITH_LM_HEAD_MAPPING.keys())
            )
        )
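
# A minimal usage sketch (assumes network access; AutoTokenizer is provided by
# tokenization_auto.py in this package):
#
#     from transformers import AutoModelWithLMHead, AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
#     model = AutoModelWithLMHead.from_pretrained('bert-base-uncased')
#     input_ids = tokenizer.encode("Hello, world", return_tensors='pt')
#     outputs = model(input_ids)  # outputs[0] holds the language modeling logits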


class AutoModelForSequenceClassification(object):
    r"""
        :class:`~transformers.AutoModelForSequenceClassification` is a generic model class
        that will be instantiated as one of the sequence classification model classes of the library
        when created with the `AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)`
        class method.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForSequenceClassification is designed to be instantiated "
            "using the `AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForSequenceClassification.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the sequence classification model classes of the library
        from a configuration.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForSequenceClassification` (DistilBERT model)
                - isInstance of `albert` configuration class: :class:`~transformers.AlbertForSequenceClassification` (ALBERT model)
                - isInstance of `camembert` configuration class: :class:`~transformers.CamembertForSequenceClassification` (CamemBERT model)
                - isInstance of `xlm-roberta` configuration class: :class:`~transformers.XLMRobertaForSequenceClassification` (XLM-RoBERTa model)
                - isInstance of `roberta` configuration class: :class:`~transformers.RobertaForSequenceClassification` (RoBERTa model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertForSequenceClassification` (Bert model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetForSequenceClassification` (XLNet model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMForSequenceClassification` (XLM model)
                - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertForSequenceClassification` (Flaubert model)


        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelForSequenceClassification.from_config(config)  # Instantiate the model (with freshly initialized weights) from the config
        """
        for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()),
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the sequence classification model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `distilbert`: :class:`~transformers.DistilBertForSequenceClassification` (DistilBERT model)
            - contains `albert`: :class:`~transformers.AlbertForSequenceClassification` (ALBERT model)
            - contains `camembert`: :class:`~transformers.CamembertForSequenceClassification` (CamemBERT model)
            - contains `xlm-roberta`: :class:`~transformers.XLMRobertaForSequenceClassification` (XLM-RoBERTa model)
            - contains `roberta`: :class:`~transformers.RobertaForSequenceClassification` (RoBERTa model)
            - contains `bert`: :class:`~transformers.BertForSequenceClassification` (Bert model)
            - contains `xlnet`: :class:`~transformers.XLNetForSequenceClassification` (XLNet model)
            - contains `flaubert`: :class:`~transformers.FlaubertForSequenceClassification` (Flaubert model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`

        Args:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Examples::

            model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForSequenceClassification.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelForSequenceClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()),
            )
        )
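
# Sketch: the classification head size follows ``config.num_labels``, and
# from_pretrained forwards unused kwargs to the configuration, so callers can
# set it inline (one option among several):
#
#     model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=3)
#     # the returned BertForSequenceClassification projects to 3 logits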


class AutoModelForQuestionAnswering(object):
    r"""
        :class:`~transformers.AutoModelForQuestionAnswering` is a generic model class
        that will be instantiated as one of the question answering model classes of the library
        when created with the `AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)`
        class method.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForQuestionAnswering is designed to be instantiated "
            "using the `AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForQuestionAnswering.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the base model classes of the library
        from a configuration.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForQuestionAnswering` (DistilBERT model)
                - isInstance of `albert` configuration class: :class:`~transformers.AlbertForQuestionAnswering` (ALBERT model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertForQuestionAnswering` (Bert model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetForQuestionAnswering` (XLNet model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMForQuestionAnswering` (XLM model)
                - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertForQuestionAnswering` (Flaubert model)

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelForQuestionAnswering.from_config(config)  # Instantiate a model (with random weights) from the configuration.
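
            # A hedged sketch (illustrative values, not from the original docs): a config
            # built from scratch also works; the resulting model is randomly initialized.
            config = BertConfig(num_hidden_layers=6)
            model = AutoModelForQuestionAnswering.from_config(config)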
        """
        for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()),
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the question answering model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `distilbert`: :class:`~transformers.DistilBertForQuestionAnswering` (DistilBERT model)
            - contains `albert`: :class:`~transformers.AlbertForQuestionAnswering` (ALBERT model)
            - contains `bert`: :class:`~transformers.BertForQuestionAnswering` (Bert model)
            - contains `xlnet`: :class:`~transformers.XLNetForQuestionAnswering` (XLNet model)
            - contains `xlm`: :class:`~transformers.XLMForQuestionAnswering` (XLM model)
            - contains `flaubert`: :class:`~transformers.FlaubertForQuestionAnswering` (Flaubert model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
        To train the model, you should first set it back in training mode with `model.train()`.

        Args:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model with the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. The configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                An optional state dictionary for the model to use instead of the state dictionary loaded from the saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case, though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force a (re-)download of the model weights and configuration files, overriding any cached versions that exist.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Examples::

            model = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForQuestionAnswering.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            config = AutoConfig.from_pretrained('bert-base-uncased', output_attentions=True)   # Update the configuration during loading.
            model = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased', config=config)
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json')
            model = AutoModelForQuestionAnswering.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
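
            # A hedged sketch (not part of the original docs): the concrete class is picked
            # from the config, e.g. a BERT checkpoint should yield BertForQuestionAnswering.
            model = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased')
            assert isinstance(model, BertForQuestionAnswering)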

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()),
            )
        )


class AutoModelForTokenClassification:
    r"""
        :class:`~transformers.AutoModelForTokenClassification` is a generic model class
        that will be instantiated as one of the token classification model classes of the library
        when created with the `AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path)`
        class method.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForTokenClassification is designed to be instantiated "
            "using the `AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForTokenClassification.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the base model classes of the library
        from a configuration.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForTokenClassification` (DistilBERT model)
                - isInstance of `xlm roberta` configuration class: :class:`~transformers.XLMRobertaForTokenClassification` (XLM-RoBERTa model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertForTokenClassification` (Bert model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetForTokenClassification` (XLNet model)
                - isInstance of `camembert` configuration class: :class:`~transformers.CamembertForTokenClassification` (Camembert model)
                - isInstance of `roberta` configuration class: :class:`~transformers.RobertaForTokenClassification` (Roberta model)

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelForTokenClassification.from_config(config)  # Instantiate a model (with random weights) from the configuration.
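
            # A hedged sketch (illustrative values, not from the original docs): a config
            # built from scratch also works; the resulting model is randomly initialized.
            config = BertConfig(num_hidden_layers=6)
            model = AutoModelForTokenClassification.from_config(config)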
        """
        for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys()),
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the question answering model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `distilbert`: :class:`~transformers.DistilBertForTokenClassification` (DistilBERT model)
            - contains `xlm-roberta`: :class:`~transformers.XLMRobertaForTokenClassification` (XLM-RoBERTa model)
            - contains `camembert`: :class:`~transformers.CamembertForTokenClassification` (Camembert model)
            - contains `bert`: :class:`~transformers.BertForTokenClassification` (Bert model)
            - contains `xlnet`: :class:`~transformers.XLNetForTokenClassification` (XLNet model)
            - contains `roberta`: :class:`~transformers.RobertaForTokenClassification` (Roberta model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
        To train the model, you should first set it back in training mode with `model.train()`.

        Args:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model with the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. The configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                An optional state dictionary for the model to use instead of the state dictionary loaded from the saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case, though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force a (re-)download of the model weights and configuration files, overriding any cached versions that exist.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Examples::

            model = AutoModelForTokenClassification.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForTokenClassification.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            config = AutoConfig.from_pretrained('bert-base-uncased', output_attentions=True)   # Update the configuration during loading.
            model = AutoModelForTokenClassification.from_pretrained('bert-base-uncased', config=config)
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json')
            model = AutoModelForTokenClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
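
            # A hedged sketch (not part of the original docs): output_loading_info=True
            # should additionally return a dict of missing/unexpected keys and error messages.
            model, loading_info = AutoModelForTokenClassification.from_pretrained('bert-base-uncased', output_loading_info=True)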

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys()),
            )
        )