modeling_auto.py 61.9 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class. """


import logging

20
21
22
23
24
25
26
27
28
29
30
31
from .configuration_auto import (
    AlbertConfig,
    BertConfig,
    CamembertConfig,
    CTRLConfig,
    DistilBertConfig,
    GPT2Config,
    OpenAIGPTConfig,
    RobertaConfig,
    T5Config,
    TransfoXLConfig,
    XLMConfig,
    XLMRobertaConfig,
    XLNetConfig,
)
from .modeling_albert import (
    ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
    AlbertForMaskedLM,
    AlbertForQuestionAnswering,
    AlbertForSequenceClassification,
    AlbertModel,
40
41
)
from .modeling_bert import (
Aymeric Augustin's avatar
Aymeric Augustin committed
42
    BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
43
44
    BertForMaskedLM,
    BertForQuestionAnswering,
Aymeric Augustin's avatar
Aymeric Augustin committed
45
    BertForSequenceClassification,
46
    BertForTokenClassification,
Aymeric Augustin's avatar
Aymeric Augustin committed
47
    BertModel,
48
)
Aymeric Augustin's avatar
Aymeric Augustin committed
49
50
51
52
53
54
from .modeling_camembert import (
    CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
    CamembertForMaskedLM,
    CamembertForSequenceClassification,
    CamembertForTokenClassification,
    CamembertModel,
55
)
Aymeric Augustin's avatar
Aymeric Augustin committed
56
from .modeling_ctrl import CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRLLMHeadModel, CTRLModel
57
from .modeling_distilbert import (
Aymeric Augustin's avatar
Aymeric Augustin committed
58
    DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
59
    DistilBertForMaskedLM,
Aymeric Augustin's avatar
Aymeric Augustin committed
60
    DistilBertForQuestionAnswering,
61
62
    DistilBertForSequenceClassification,
    DistilBertForTokenClassification,
Aymeric Augustin's avatar
Aymeric Augustin committed
63
    DistilBertModel,
64
)
Aymeric Augustin's avatar
Aymeric Augustin committed
65
66
67
68
69
70
71
72
from .modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2Model
from .modeling_openai import OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OpenAIGPTModel
from .modeling_roberta import (
    ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
    RobertaForMaskedLM,
    RobertaForSequenceClassification,
    RobertaForTokenClassification,
    RobertaModel,
73
)
Aymeric Augustin's avatar
Aymeric Augustin committed
74
75
76
77
78
79
80
81
from .modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5Model, T5WithLMHeadModel
from .modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TransfoXLLMHeadModel, TransfoXLModel
from .modeling_xlm import (
    XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
    XLMForQuestionAnswering,
    XLMForSequenceClassification,
    XLMModel,
    XLMWithLMHeadModel,
82
83
)
from .modeling_xlm_roberta import (
Aymeric Augustin's avatar
Aymeric Augustin committed
84
    XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
85
    XLMRobertaForMaskedLM,
Aymeric Augustin's avatar
Aymeric Augustin committed
86
    XLMRobertaForSequenceClassification,
87
    XLMRobertaForTokenClassification,
Aymeric Augustin's avatar
Aymeric Augustin committed
88
89
90
91
92
93
94
95
96
    XLMRobertaModel,
)
from .modeling_xlnet import (
    XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
    XLNetForQuestionAnswering,
    XLNetForSequenceClassification,
    XLNetForTokenClassification,
    XLNetLMHeadModel,
    XLNetModel,
97
)
thomwolf's avatar
thomwolf committed
98

thomwolf's avatar
thomwolf committed
99

100
logger = logging.getLogger(__name__)
thomwolf's avatar
thomwolf committed
101
102


103
104
ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
    (key, value)
105
106
107
108
109
110
111
112
113
114
115
116
117
    for pretrained_map in [
        BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
        TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
        GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
        CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
        ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        T5_PRETRAINED_MODEL_ARCHIVE_MAP,
118
        XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
119
120
121
    ]
    for key, value, in pretrained_map.items()
)
122
123


thomwolf's avatar
thomwolf committed
124
125
class AutoModel(object):
    r"""
        :class:`~transformers.AutoModel` is a generic model class
        that will be instantiated as one of the base model classes of the library
        when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
        or the `AutoModel.from_config(config)` class methods.

        The `from_pretrained()` method takes care of returning the correct model class instance
        using pattern matching on the `pretrained_model_name_or_path` string.

        The base model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: T5Model (T5 model)
            - contains `distilbert`: DistilBertModel (DistilBERT model)
            - contains `albert`: AlbertModel (ALBERT model)
            - contains `camembert`: CamembertModel (CamemBERT model)
            - contains `xlm-roberta`: XLMRobertaModel (XLM-RoBERTa model)
            - contains `roberta`: RobertaModel (RoBERTa model)
            - contains `bert`: BertModel (Bert model)
            - contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
            - contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
            - contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
            - contains `xlnet`: XLNetModel (XLNet model)
            - contains `xlm`: XLMModel (XLM model)
            - contains `ctrl`: CTRLModel (Salesforce CTRL model)

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        # Factory-only class: direct construction is always an error.
        raise EnvironmentError(
            "AutoModel is designed to be instantiated "
            "using the `AutoModel.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModel.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the base model classes of the library
        from a configuration.

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                The model class to instantiate is selected based on the configuration class:
                    - isInstance of `t5` configuration class: T5Model (T5 model)
                    - isInstance of `distilbert` configuration class: DistilBertModel (DistilBERT model)
                    - isInstance of `albert` configuration class: AlbertModel (ALBERT model)
                    - isInstance of `camembert` configuration class: CamembertModel (CamemBERT model)
                    - isInstance of `xlm-roberta` configuration class: XLMRobertaModel (XLM-RoBERTa model)
                    - isInstance of `roberta` configuration class: RobertaModel (RoBERTa model)
                    - isInstance of `bert` configuration class: BertModel (Bert model)
                    - isInstance of `openai-gpt` configuration class: OpenAIGPTModel (OpenAI GPT model)
                    - isInstance of `gpt2` configuration class: GPT2Model (OpenAI GPT-2 model)
                    - isInstance of `transfo-xl` configuration class: TransfoXLModel (Transformer-XL model)
                    - isInstance of `xlnet` configuration class: XLNetModel (XLNet model)
                    - isInstance of `xlm` configuration class: XLMModel (XLM model)
                    - isInstance of `ctrl` configuration class: CTRLModel (Salesforce CTRL model)

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModel.from_config(config)  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
        """
        # NOTE(review): derived configuration classes (Camembert / XLM-RoBERTa
        # presumably subclass RobertaConfig -- confirm in the configuration_*
        # modules) must be tested BEFORE their parent class, otherwise
        # isinstance() matches the parent first and the wrong model is returned.
        if isinstance(config, T5Config):
            return T5Model(config)
        elif isinstance(config, DistilBertConfig):
            return DistilBertModel(config)
        elif isinstance(config, AlbertConfig):
            return AlbertModel(config)
        elif isinstance(config, CamembertConfig):
            return CamembertModel(config)
        elif isinstance(config, XLMRobertaConfig):
            return XLMRobertaModel(config)
        elif isinstance(config, RobertaConfig):
            return RobertaModel(config)
        elif isinstance(config, BertConfig):
            return BertModel(config)
        elif isinstance(config, OpenAIGPTConfig):
            return OpenAIGPTModel(config)
        elif isinstance(config, GPT2Config):
            return GPT2Model(config)
        elif isinstance(config, TransfoXLConfig):
            return TransfoXLModel(config)
        elif isinstance(config, XLNetConfig):
            return XLNetModel(config)
        elif isinstance(config, XLMConfig):
            return XLMModel(config)
        elif isinstance(config, CTRLConfig):
            return CTRLModel(config)
        raise ValueError("Unrecognized configuration class {}".format(config))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the base model classes of the library
        from a pre-trained model configuration.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: T5Model (T5 model)
            - contains `distilbert`: DistilBertModel (DistilBERT model)
            - contains `albert`: AlbertModel (ALBERT model)
            - contains `camembert`: CamembertModel (CamemBERT model)
            - contains `xlm-roberta`: XLMRobertaModel (XLM-RoBERTa model)
            - contains `roberta`: RobertaModel (RoBERTa model)
            - contains `bert`: BertModel (Bert model)
            - contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
            - contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
            - contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
            - contains `xlnet`: XLNetModel (XLNet model)
            - contains `xlm`: XLMModel (XLM model)
            - contains `ctrl`: CTRLModel (Salesforce CTRL model)

            The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
            To train the model, you should first set it back in training mode with `model.train()`

        Params:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            model = AutoModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModel.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = AutoModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
            assert model.config.output_attention == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        # Substring dispatch: more specific patterns ('distilbert', 'xlm-roberta',
        # 'transfo-xl') must be tested before the shorter patterns they contain
        # ('bert', 'roberta', 'xlm') -- keep this ordering when adding models.
        if "t5" in pretrained_model_name_or_path:
            return T5Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "distilbert" in pretrained_model_name_or_path:
            return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "albert" in pretrained_model_name_or_path:
            return AlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "camembert" in pretrained_model_name_or_path:
            return CamembertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "xlm-roberta" in pretrained_model_name_or_path:
            return XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "roberta" in pretrained_model_name_or_path:
            return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "bert" in pretrained_model_name_or_path:
            return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "openai-gpt" in pretrained_model_name_or_path:
            return OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "gpt2" in pretrained_model_name_or_path:
            return GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "transfo-xl" in pretrained_model_name_or_path:
            return TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "xlnet" in pretrained_model_name_or_path:
            return XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "xlm" in pretrained_model_name_or_path:
            return XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "ctrl" in pretrained_model_name_or_path:
            return CTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        # Fixed message: closed the dangling quote after 'roberta' and added the
        # supported-but-unlisted 't5' pattern.
        raise ValueError(
            "Unrecognized model identifier in {}. Should contain one of "
            "'t5', 'distilbert', 'albert', 'camembert', 'xlm-roberta', 'roberta', "
            "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', 'xlm', 'ctrl'".format(
                pretrained_model_name_or_path
            )
        )
322
323
324
325


class AutoModelWithLMHead(object):
    r"""
326
        :class:`~transformers.AutoModelWithLMHead` is a generic model class
327
328
329
330
331
332
333
334
335
        that will be instantiated as one of the language modeling model classes of the library
        when created with the `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)`
        class method.

        The `from_pretrained()` method takes care of returning the correct model class instance
        using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
336
            - contains `t5`: T5ModelWithLMHead (T5 model)
337
            - contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
Elad Segal's avatar
Elad Segal committed
338
            - contains `albert`: AlbertForMaskedLM (ALBERT model)
Evpok Padding's avatar
Evpok Padding committed
339
            - contains `camembert`: CamembertForMaskedLM (CamemBERT model)
340
            - contains `xlm-roberta`: XLMRobertaForMaskedLM (XLM-RoBERTa model)
341
342
343
344
345
346
347
            - contains `roberta`: RobertaForMaskedLM (RoBERTa model)
            - contains `bert`: BertForMaskedLM (Bert model)
            - contains `openai-gpt`: OpenAIGPTLMHeadModel (OpenAI GPT model)
            - contains `gpt2`: GPT2LMHeadModel (OpenAI GPT-2 model)
            - contains `transfo-xl`: TransfoXLLMHeadModel (Transformer-XL model)
            - contains `xlnet`: XLNetLMHeadModel (XLNet model)
            - contains `xlm`: XLMWithLMHeadModel (XLM model)
Elad Segal's avatar
Elad Segal committed
348
            - contains `ctrl`: CTRLLMHeadModel (Salesforce CTRL model)
349
350
351

        This class cannot be instantiated using `__init__()` (throws an error).
    """
352

353
    def __init__(self):
354
355
        raise EnvironmentError(
            "AutoModelWithLMHead is designed to be instantiated "
356
            "using the `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` or "
357
358
            "`AutoModelWithLMHead.from_config(config)` methods."
        )
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the language modeling model classes of the library
        from a configuration.

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                The model class to instantiate is selected based on the configuration class:
                    - isInstance of `t5` configuration class: T5WithLMHeadModel (T5 model)
                    - isInstance of `distilbert` configuration class: DistilBertForMaskedLM (DistilBERT model)
                    - isInstance of `albert` configuration class: AlbertForMaskedLM (ALBERT model)
                    - isInstance of `camembert` configuration class: CamembertForMaskedLM (CamemBERT model)
                    - isInstance of `xlm-roberta` configuration class: XLMRobertaForMaskedLM (XLM-RoBERTa model)
                    - isInstance of `roberta` configuration class: RobertaForMaskedLM (RoBERTa model)
                    - isInstance of `bert` configuration class: BertForMaskedLM (Bert model)
                    - isInstance of `openai-gpt` configuration class: OpenAIGPTLMHeadModel (OpenAI GPT model)
                    - isInstance of `gpt2` configuration class: GPT2LMHeadModel (OpenAI GPT-2 model)
                    - isInstance of `transfo-xl` configuration class: TransfoXLLMHeadModel (Transformer-XL model)
                    - isInstance of `xlnet` configuration class: XLNetLMHeadModel (XLNet model)
                    - isInstance of `xlm` configuration class: XLMWithLMHeadModel (XLM model)
                    - isInstance of `ctrl` configuration class: CTRLLMHeadModel (Salesforce CTRL model)

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelWithLMHead.from_config(config)  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
        """
        # NOTE(review): derived configuration classes (Camembert / XLM-RoBERTa
        # presumably subclass RobertaConfig -- confirm in the configuration_*
        # modules) must be tested BEFORE their parent class, otherwise
        # isinstance() matches the parent first and the wrong model is returned.
        # Also added the previously missing Albert / Camembert / T5 branches,
        # whose model classes are already imported at the top of this module.
        if isinstance(config, T5Config):
            return T5WithLMHeadModel(config)
        elif isinstance(config, DistilBertConfig):
            return DistilBertForMaskedLM(config)
        elif isinstance(config, AlbertConfig):
            return AlbertForMaskedLM(config)
        elif isinstance(config, CamembertConfig):
            return CamembertForMaskedLM(config)
        elif isinstance(config, XLMRobertaConfig):
            return XLMRobertaForMaskedLM(config)
        elif isinstance(config, RobertaConfig):
            return RobertaForMaskedLM(config)
        elif isinstance(config, BertConfig):
            return BertForMaskedLM(config)
        elif isinstance(config, OpenAIGPTConfig):
            return OpenAIGPTLMHeadModel(config)
        elif isinstance(config, GPT2Config):
            return GPT2LMHeadModel(config)
        elif isinstance(config, TransfoXLConfig):
            return TransfoXLLMHeadModel(config)
        elif isinstance(config, XLNetConfig):
            return XLNetLMHeadModel(config)
        elif isinstance(config, XLMConfig):
            return XLMWithLMHeadModel(config)
        elif isinstance(config, CTRLConfig):
            return CTRLLMHeadModel(config)
        raise ValueError("Unrecognized configuration class {}".format(config))
403
404
405
406
407
408
409
410
411
412
413

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the language modeling model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
414
            - contains `t5`: T5ModelWithLMHead (T5 model)
415
            - contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
Elad Segal's avatar
Elad Segal committed
416
            - contains `albert`: AlbertForMaskedLM (ALBERT model)
Evpok Padding's avatar
Evpok Padding committed
417
            - contains `camembert`: CamembertForMaskedLM (CamemBERT model)
418
            - contains `xlm-roberta`: XLMRobertaForMaskedLM (XLM-RoBERTa model)
419
420
421
422
423
424
425
            - contains `roberta`: RobertaForMaskedLM (RoBERTa model)
            - contains `bert`: BertForMaskedLM (Bert model)
            - contains `openai-gpt`: OpenAIGPTLMHeadModel (OpenAI GPT model)
            - contains `gpt2`: GPT2LMHeadModel (OpenAI GPT-2 model)
            - contains `transfo-xl`: TransfoXLLMHeadModel (Transformer-XL model)
            - contains `xlnet`: XLNetLMHeadModel (XLNet model)
            - contains `xlm`: XLMWithLMHeadModel (XLM model)
Elad Segal's avatar
Elad Segal committed
426
            - contains `ctrl`: CTRLLMHeadModel (Salesforce CTRL model)
427
428
429
430
431

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`

        Params:
thomwolf's avatar
thomwolf committed
432
433
434
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
435
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
436
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
thomwolf's avatar
thomwolf committed
437
438
439
440
441
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaning positional arguments will be passed to the underlying model's ``__init__`` method

442
            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
thomwolf's avatar
thomwolf committed
443
444
445
                Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
446
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory.
thomwolf's avatar
thomwolf committed
447
448
449
450
                - the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionnary for the model to use instead of a state dictionary loaded from saved weights file.
451
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
452
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
thomwolf's avatar
thomwolf committed
453
454

            cache_dir: (`optional`) string:
455
456
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.
thomwolf's avatar
thomwolf committed
457
458
459

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
460
461
            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
thomwolf's avatar
thomwolf committed
462
463
464
465
466
467
468
469
470
471
472
473

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
474
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
475
476
477
478
479
480
481
482
483
484
485
486

        Examples::

            model = AutoModelWithLMHead.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelWithLMHead.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = AutoModelWithLMHead.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
            assert model.config.output_attention == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
487
        if "t5" in pretrained_model_name_or_path:
488
            return T5WithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
489
        elif "distilbert" in pretrained_model_name_or_path:
490
            return DistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
491
        elif "albert" in pretrained_model_name_or_path:
Elad Segal's avatar
Elad Segal committed
492
            return AlbertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
493
        elif "camembert" in pretrained_model_name_or_path:
494
            return CamembertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
495
        elif "xlm-roberta" in pretrained_model_name_or_path:
496
            return XLMRobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
497
        elif "roberta" in pretrained_model_name_or_path:
498
            return RobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
499
        elif "bert" in pretrained_model_name_or_path:
500
            return BertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
501
        elif "openai-gpt" in pretrained_model_name_or_path:
502
            return OpenAIGPTLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
503
        elif "gpt2" in pretrained_model_name_or_path:
504
            return GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
505
        elif "transfo-xl" in pretrained_model_name_or_path:
506
            return TransfoXLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
507
        elif "xlnet" in pretrained_model_name_or_path:
508
            return XLNetLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
509
        elif "xlm" in pretrained_model_name_or_path:
510
            return XLMWithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
511
        elif "ctrl" in pretrained_model_name_or_path:
keskarnitish's avatar
keskarnitish committed
512
            return CTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
513
514
515
516
517
518
519
        raise ValueError(
            "Unrecognized model identifier in {}. Should contains one of "
            "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
            "'xlm-roberta', 'xlm', 'roberta','ctrl', 'distilbert', 'camembert', 'albert'".format(
                pretrained_model_name_or_path
            )
        )
520
521
522
523


class AutoModelForSequenceClassification(object):
    r"""
        :class:`~transformers.AutoModelForSequenceClassification` is a generic model class
        that will be instantiated as one of the sequence classification model classes of the library
        when created with the `AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)`
        class method.

        The `from_pretrained()` method takes care of returning the correct model class instance
        using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model)
            - contains `albert`: AlbertForSequenceClassification (ALBERT model)
            - contains `camembert`: CamembertForSequenceClassification (CamemBERT model)
            - contains `xlm-roberta`: XLMRobertaForSequenceClassification (XLM-RoBERTa model)
            - contains `roberta`: RobertaForSequenceClassification (RoBERTa model)
            - contains `bert`: BertForSequenceClassification (Bert model)
            - contains `xlnet`: XLNetForSequenceClassification (XLNet model)
            - contains `xlm`: XLMForSequenceClassification (XLM model)

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForSequenceClassification is designed to be instantiated "
            "using the `AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForSequenceClassification.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the sequence classification model classes of the library
        from a configuration.

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                The model class to instantiate is selected based on the configuration class:
                    - isInstance of `albert` configuration class: AlbertForSequenceClassification (ALBERT model)
                    - isInstance of `camembert` configuration class: CamembertForSequenceClassification (CamemBERT model)
                    - isInstance of `distilbert` configuration class: DistilBertForSequenceClassification (DistilBERT model)
                    - isInstance of `xlm-roberta` configuration class: XLMRobertaForSequenceClassification (XLM-RoBERTa model)
                    - isInstance of `roberta` configuration class: RobertaForSequenceClassification (RoBERTa model)
                    - isInstance of `bert` configuration class: BertForSequenceClassification (Bert model)
                    - isInstance of `xlnet` configuration class: XLNetForSequenceClassification (XLNet model)
                    - isInstance of `xlm` configuration class: XLMForSequenceClassification (XLM model)

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelForSequenceClassification.from_config(config)  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
        """
        # NOTE: subclass configurations must be tested before their base class.
        # CamembertConfig and XLMRobertaConfig both derive from RobertaConfig
        # (which itself derives from BertConfig), so checking RobertaConfig or
        # BertConfig first would make the subclass branches unreachable.
        if isinstance(config, AlbertConfig):
            return AlbertForSequenceClassification(config)
        elif isinstance(config, CamembertConfig):
            return CamembertForSequenceClassification(config)
        elif isinstance(config, DistilBertConfig):
            return DistilBertForSequenceClassification(config)
        elif isinstance(config, XLMRobertaConfig):
            return XLMRobertaForSequenceClassification(config)
        elif isinstance(config, RobertaConfig):
            return RobertaForSequenceClassification(config)
        elif isinstance(config, BertConfig):
            return BertForSequenceClassification(config)
        elif isinstance(config, XLNetConfig):
            return XLNetForSequenceClassification(config)
        elif isinstance(config, XLMConfig):
            return XLMForSequenceClassification(config)
        raise ValueError("Unrecognized configuration class {}".format(config))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the sequence classification model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model)
            - contains `albert`: AlbertForSequenceClassification (ALBERT model)
            - contains `camembert`: CamembertForSequenceClassification (CamemBERT model)
            - contains `xlm-roberta`: XLMRobertaForSequenceClassification (XLM-RoBERTa model)
            - contains `roberta`: RobertaForSequenceClassification (RoBERTa model)
            - contains `bert`: BertForSequenceClassification (Bert model)
            - contains `xlnet`: XLNetForSequenceClassification (XLNet model)
            - contains `xlm`: XLMForSequenceClassification (XLM model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`

        Params:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForSequenceClassification.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
            assert model.config.output_attention == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelForSequenceClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        # Pattern order matters: more specific substrings must be tested before
        # their substrings ("distilbert" before "bert", "xlm-roberta" before
        # "roberta" and "xlm", etc.).
        if "distilbert" in pretrained_model_name_or_path:
            return DistilBertForSequenceClassification.from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        elif "albert" in pretrained_model_name_or_path:
            return AlbertForSequenceClassification.from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        elif "camembert" in pretrained_model_name_or_path:
            return CamembertForSequenceClassification.from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        elif "xlm-roberta" in pretrained_model_name_or_path:
            return XLMRobertaForSequenceClassification.from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        elif "roberta" in pretrained_model_name_or_path:
            return RobertaForSequenceClassification.from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        elif "bert" in pretrained_model_name_or_path:
            return BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "xlnet" in pretrained_model_name_or_path:
            return XLNetForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "xlm" in pretrained_model_name_or_path:
            return XLMForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        raise ValueError(
            "Unrecognized model identifier in {}. Should contains one of "
            "'bert', 'xlnet', 'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'albert'".format(
                pretrained_model_name_or_path
            )
        )


class AutoModelForQuestionAnswering(object):
    r"""
705
        :class:`~transformers.AutoModelForQuestionAnswering` is a generic model class
706
707
708
709
710
711
712
713
714
        that will be instantiated as one of the question answering model classes of the library
        when created with the `AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)`
        class method.

        The `from_pretrained()` method takes care of returning the correct model class instance
        using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
715
            - contains `distilbert`: DistilBertForQuestionAnswering (DistilBERT model)
Elad Segal's avatar
Elad Segal committed
716
            - contains `albert`: AlbertForQuestionAnswering (ALBERT model)
717
718
719
720
721
722
            - contains `bert`: BertForQuestionAnswering (Bert model)
            - contains `xlnet`: XLNetForQuestionAnswering (XLNet model)
            - contains `xlm`: XLMForQuestionAnswering (XLM model)

        This class cannot be instantiated using `__init__()` (throws an error).
    """
723

724
    def __init__(self):
725
726
        raise EnvironmentError(
            "AutoModelForQuestionAnswering is designed to be instantiated "
727
            "using the `AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)` or "
728
729
            "`AutoModelForQuestionAnswering.from_config(config)` methods."
        )
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the base model classes of the library
        from a configuration.

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                The model class to instantiate is selected based on the configuration class:
                    - isInstance of `distilbert` configuration class: DistilBertModel (DistilBERT model)
                    - isInstance of `bert` configuration class: BertModel (Bert model)
                    - isInstance of `xlnet` configuration class: XLNetModel (XLNet model)
                    - isInstance of `xlm` configuration class: XLMModel (XLM model)

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelForSequenceClassification.from_config(config)  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
        """
748
        if isinstance(config, AlbertConfig):
749
750
751
752
753
754
755
756
757
758
            return AlbertForQuestionAnswering(config)
        elif isinstance(config, DistilBertConfig):
            return DistilBertForQuestionAnswering(config)
        elif isinstance(config, BertConfig):
            return BertForQuestionAnswering(config)
        elif isinstance(config, XLNetConfig):
            return XLNetForQuestionAnswering(config)
        elif isinstance(config, XLMConfig):
            return XLMForQuestionAnswering(config)
        raise ValueError("Unrecognized configuration class {}".format(config))
759
760
761
762
763
764
765
766
767
768
769

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the question answering model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
770
            - contains `distilbert`: DistilBertForQuestionAnswering (DistilBERT model)
Elad Segal's avatar
Elad Segal committed
771
            - contains `albert`: AlbertForQuestionAnswering (ALBERT model)
772
773
774
775
776
777
778
779
            - contains `bert`: BertForQuestionAnswering (Bert model)
            - contains `xlnet`: XLNetForQuestionAnswering (XLNet model)
            - contains `xlm`: XLMForQuestionAnswering (XLM model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`

        Params:
thomwolf's avatar
thomwolf committed
780
781
782
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
783
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
784
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
thomwolf's avatar
thomwolf committed
785
786
787
788
789
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaning positional arguments will be passed to the underlying model's ``__init__`` method

790
            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
thomwolf's avatar
thomwolf committed
791
792
793
                Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
794
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory.
thomwolf's avatar
thomwolf committed
795
796
797
798
                - the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionnary for the model to use instead of a state dictionary loaded from saved weights file.
799
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
800
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
thomwolf's avatar
thomwolf committed
801
802

            cache_dir: (`optional`) string:
803
804
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.
thomwolf's avatar
thomwolf committed
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
820
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
821
822
823
824
825
826
827
828
829
830
831
832

        Examples::

            model = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForQuestionAnswering.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
            assert model.config.output_attention == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelForQuestionAnswering.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
833
        if "distilbert" in pretrained_model_name_or_path:
834
            return DistilBertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
835
        elif "albert" in pretrained_model_name_or_path:
Elad Segal's avatar
Elad Segal committed
836
            return AlbertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
837
        elif "bert" in pretrained_model_name_or_path:
838
            return BertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
839
        elif "xlnet" in pretrained_model_name_or_path:
840
            return XLNetForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
841
        elif "xlm" in pretrained_model_name_or_path:
842
843
            return XLMForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

844
845
846
847
        raise ValueError(
            "Unrecognized model identifier in {}. Should contains one of "
            "'bert', 'xlnet', 'xlm', 'distilbert', 'albert'".format(pretrained_model_name_or_path)
        )
848
849
850
851


class AutoModelForTokenClassification:
    r"""
        :class:`~transformers.AutoModelForTokenClassification` is a generic model class
        that will be instantiated as one of the token classification model classes of the library
        when created with the `AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path)`
        class method.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForTokenClassification is designed to be instantiated "
            "using the `AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForTokenClassification.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the token classification model classes of the library
        from a configuration.

        The model class to instantiate is selected based on the configuration class:

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                The model class to instantiate is selected based on the configuration class:
                    - isInstance of `camembert` configuration class: CamembertForTokenClassification (Camembert model)
                    - isInstance of `distilbert` configuration class: DistilBertForTokenClassification (DistilBERT model)
                    - isInstance of `bert` configuration class: BertForTokenClassification (Bert model)
                    - isInstance of `xlnet` configuration class: XLNetForTokenClassification (XLNet model)
                    - isInstance of `xlm-roberta` configuration class: XLMRobertaForTokenClassification (XLM-RoBERTa model)
                    - isInstance of `roberta` configuration class: RobertaForTokenClassification (Roberta model)

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelForTokenClassification.from_config(config)  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
        """
        # Subclasses must be tested before their parent classes: CamembertConfig and
        # XLMRobertaConfig both derive from RobertaConfig, so checking RobertaConfig
        # earlier would wrongly capture them.
        if isinstance(config, CamembertConfig):
            return CamembertForTokenClassification(config)
        elif isinstance(config, DistilBertConfig):
            return DistilBertForTokenClassification(config)
        elif isinstance(config, BertConfig):
            return BertForTokenClassification(config)
        elif isinstance(config, XLNetConfig):
            return XLNetForTokenClassification(config)
        elif isinstance(config, XLMRobertaConfig):
            return XLMRobertaForTokenClassification(config)
        elif isinstance(config, RobertaConfig):
            return RobertaForTokenClassification(config)
        raise ValueError("Unrecognized configuration class {}".format(config))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the token classification model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `camembert`: CamembertForTokenClassification (Camembert model)
            - contains `distilbert`: DistilBertForTokenClassification (DistilBERT model)
            - contains `xlm-roberta`: XLMRobertaForTokenClassification (XLM-RoBERTa model)
            - contains `roberta`: RobertaForTokenClassification (Roberta model)
            - contains `bert`: BertForTokenClassification (Bert model)
            - contains `xlnet`: XLNetForTokenClassification (XLNet model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`

        Params:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaning positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory.
                - the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionnary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            model = AutoModelForTokenClassification.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForTokenClassification.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = AutoModelForTokenClassification.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
            assert model.config.output_attention == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelForTokenClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        # Pattern order matters: 'camembert' and 'distilbert' both contain 'bert',
        # and 'xlm-roberta' contains 'roberta', so the more specific substrings
        # are matched first.
        if "camembert" in pretrained_model_name_or_path:
            return CamembertForTokenClassification.from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        elif "distilbert" in pretrained_model_name_or_path:
            return DistilBertForTokenClassification.from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        elif "xlm-roberta" in pretrained_model_name_or_path:
            return XLMRobertaForTokenClassification.from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        elif "roberta" in pretrained_model_name_or_path:
            return RobertaForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "bert" in pretrained_model_name_or_path:
            return BertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "xlnet" in pretrained_model_name_or_path:
            return XLNetForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        raise ValueError(
            "Unrecognized model identifier in {}. Should contains one of "
            "'bert', 'xlnet', 'camembert', 'distilbert', 'xlm-roberta', 'roberta'".format(
                pretrained_model_name_or_path
            )
        )