modeling_tf_auto.py 35.5 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class. """

from __future__ import absolute_import, division, print_function, unicode_literals

import logging

from .modeling_tf_bert import TFBertModel, TFBertForMaskedLM, TFBertForSequenceClassification, TFBertForQuestionAnswering
thomwolf's avatar
thomwolf committed
22
23
24
25
26
27
28
from .modeling_tf_openai import TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel
from .modeling_tf_gpt2 import TFGPT2Model, TFGPT2LMHeadModel
from .modeling_tf_transfo_xl import TFTransfoXLModel, TFTransfoXLLMHeadModel
from .modeling_tf_xlnet import TFXLNetModel, TFXLNetLMHeadModel, TFXLNetForSequenceClassification, TFXLNetForQuestionAnswering
from .modeling_tf_xlm import TFXLMModel, TFXLMWithLMHeadModel, TFXLMForSequenceClassification, TFXLMForQuestionAnsweringSimple
from .modeling_tf_roberta import TFRobertaModel, TFRobertaForMaskedLM, TFRobertaForSequenceClassification
from .modeling_tf_distilbert import TFDistilBertModel, TFDistilBertForQuestionAnswering, TFDistilBertForMaskedLM, TFDistilBertForSequenceClassification
thomwolf's avatar
thomwolf committed
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46

from .file_utils import add_start_docstrings

logger = logging.getLogger(__name__)


class TFAutoModel(object):
    r"""
        :class:`~pytorch_transformers.TFAutoModel` is a generic model class
        that will be instantiated as one of the base model classes of the library
        when created with the `TFAutoModel.from_pretrained(pretrained_model_name_or_path)`
        class method.

        The `from_pretrained()` method takes care of returning the correct model class instance
        using pattern matching on the `pretrained_model_name_or_path` string.

        The base model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `distilbert`: TFDistilBertModel (DistilBERT model)
            - contains `roberta`: TFRobertaModel (RoBERTa model)
            - contains `bert`: TFBertModel (Bert model)
            - contains `openai-gpt`: TFOpenAIGPTModel (OpenAI GPT model)
            - contains `gpt2`: TFGPT2Model (OpenAI GPT-2 model)
            - contains `transfo-xl`: TFTransfoXLModel (Transformer-XL model)
            - contains `xlnet`: TFXLNetModel (XLNet model)
            - contains `xlm`: TFXLMModel (XLM model)

        This class cannot be instantiated using `__init__()` (throws an error).
    """
    def __init__(self):
        """Always raises: instances must be created via ``TFAutoModel.from_pretrained``."""
        raise EnvironmentError("TFAutoModel is designed to be instantiated "
            "using the `TFAutoModel.from_pretrained(pretrained_model_name_or_path)` method.")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the base model classes of the library
        from a pre-trained model configuration.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `distilbert`: TFDistilBertModel (DistilBERT model)
            - contains `roberta`: TFRobertaModel (RoBERTa model)
            - contains `bert`: TFBertModel (Bert model)
            - contains `openai-gpt`: TFOpenAIGPTModel (OpenAI GPT model)
            - contains `gpt2`: TFGPT2Model (OpenAI GPT-2 model)
            - contains `transfo-xl`: TFTransfoXLModel (Transformer-XL model)
            - contains `xlnet`: TFXLNetModel (XLNet model)
            - contains `xlm`: TFXLMModel (XLM model)

        Params:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing model weights saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `PyTorch, TF 1.X or TF 2.0 checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In the case of a PyTorch checkpoint, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument.

            from_pt: (`Optional`) Boolean
                Set to True if the checkpoint is a PyTorch checkpoint.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~pytorch_transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and :func:`~pytorch_transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attentions=True``). Behaves differently depending on whether a `config` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~pytorch_transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

        Raises:
            ValueError: if ``pretrained_model_name_or_path`` contains none of the patterns listed above.

        Examples::

            model = TFAutoModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = TFAutoModel.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/bert_model/')`
            model = TFAutoModel.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a PyTorch checkpoint file instead of a TF 2.0 model (slower)
            config = AutoConfig.from_json_file('./pt_model/bert_pt_model_config.json')
            model = TFAutoModel.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)

        """
        # Order matters: 'distilbert' and 'roberta' both contain the substring
        # 'bert', so they must be tested before the generic 'bert' pattern.
        if 'distilbert' in pretrained_model_name_or_path:
            return TFDistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'roberta' in pretrained_model_name_or_path:
            return TFRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'bert' in pretrained_model_name_or_path:
            return TFBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'openai-gpt' in pretrained_model_name_or_path:
            return TFOpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'gpt2' in pretrained_model_name_or_path:
            return TFGPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'transfo-xl' in pretrained_model_name_or_path:
            return TFTransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'xlnet' in pretrained_model_name_or_path:
            return TFXLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'xlm' in pretrained_model_name_or_path:
            return TFXLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        raise ValueError("Unrecognized model identifier in {}. Should contains one of "
                         "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
                         "'xlm', 'roberta'".format(pretrained_model_name_or_path))


class TFAutoModelWithLMHead(object):
    r"""
        :class:`~pytorch_transformers.TFAutoModelWithLMHead` is a generic model class
        that will be instantiated as one of the language modeling model classes of the library
        when created with the `TFAutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)`
        class method.

        The `from_pretrained()` method takes care of returning the correct model class instance
        using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `distilbert`: TFDistilBertForMaskedLM (DistilBERT model)
            - contains `roberta`: TFRobertaForMaskedLM (RoBERTa model)
            - contains `bert`: TFBertForMaskedLM (Bert model)
            - contains `openai-gpt`: TFOpenAIGPTLMHeadModel (OpenAI GPT model)
            - contains `gpt2`: TFGPT2LMHeadModel (OpenAI GPT-2 model)
            - contains `transfo-xl`: TFTransfoXLLMHeadModel (Transformer-XL model)
            - contains `xlnet`: TFXLNetLMHeadModel (XLNet model)
            - contains `xlm`: TFXLMWithLMHeadModel (XLM model)

        This class cannot be instantiated using `__init__()` (throws an error).
    """
    def __init__(self):
        """Always raises: instances must be created via ``TFAutoModelWithLMHead.from_pretrained``."""
        raise EnvironmentError("TFAutoModelWithLMHead is designed to be instantiated "
            "using the `TFAutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` method.")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the language modeling model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `distilbert`: TFDistilBertForMaskedLM (DistilBERT model)
            - contains `roberta`: TFRobertaForMaskedLM (RoBERTa model)
            - contains `bert`: TFBertForMaskedLM (Bert model)
            - contains `openai-gpt`: TFOpenAIGPTLMHeadModel (OpenAI GPT model)
            - contains `gpt2`: TFGPT2LMHeadModel (OpenAI GPT-2 model)
            - contains `transfo-xl`: TFTransfoXLLMHeadModel (Transformer-XL model)
            - contains `xlnet`: TFXLNetLMHeadModel (XLNet model)
            - contains `xlm`: TFXLMWithLMHeadModel (XLM model)

        Params:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing model weights saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `PyTorch, TF 1.X or TF 2.0 checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In the case of a PyTorch checkpoint, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument.

            from_pt: (`Optional`) Boolean
                Set to True if the checkpoint is a PyTorch checkpoint.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~pytorch_transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and :func:`~pytorch_transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attentions=True``). Behaves differently depending on whether a `config` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~pytorch_transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

        Raises:
            ValueError: if ``pretrained_model_name_or_path`` contains none of the patterns listed above.

        Examples::

            model = TFAutoModelWithLMHead.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = TFAutoModelWithLMHead.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/bert_model/')`
            model = TFAutoModelWithLMHead.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a PyTorch checkpoint file instead of a TF 2.0 model (slower)
            config = AutoConfig.from_json_file('./pt_model/bert_pt_model_config.json')
            model = TFAutoModelWithLMHead.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)

        """
        # Order matters: 'distilbert' and 'roberta' both contain the substring
        # 'bert', so they must be tested before the generic 'bert' pattern.
        if 'distilbert' in pretrained_model_name_or_path:
            return TFDistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'roberta' in pretrained_model_name_or_path:
            return TFRobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'bert' in pretrained_model_name_or_path:
            return TFBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'openai-gpt' in pretrained_model_name_or_path:
            return TFOpenAIGPTLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'gpt2' in pretrained_model_name_or_path:
            return TFGPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'transfo-xl' in pretrained_model_name_or_path:
            return TFTransfoXLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'xlnet' in pretrained_model_name_or_path:
            return TFXLNetLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'xlm' in pretrained_model_name_or_path:
            return TFXLMWithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        raise ValueError("Unrecognized model identifier in {}. Should contains one of "
                         "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
                         "'xlm', 'roberta'".format(pretrained_model_name_or_path))


class TFAutoModelForSequenceClassification(object):
    r"""
        :class:`~pytorch_transformers.TFAutoModelForSequenceClassification` is a generic model class
        that will be instantiated as one of the sequence classification model classes of the library
        when created with the `TFAutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)`
        class method.

        The `from_pretrained()` method takes care of returning the correct model class instance
        using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `distilbert`: TFDistilBertForSequenceClassification (DistilBERT model)
            - contains `roberta`: TFRobertaForSequenceClassification (RoBERTa model)
            - contains `bert`: TFBertForSequenceClassification (Bert model)
            - contains `xlnet`: TFXLNetForSequenceClassification (XLNet model)
            - contains `xlm`: TFXLMForSequenceClassification (XLM model)

        This class cannot be instantiated using `__init__()` (throws an error).
    """
    def __init__(self):
        """Always raises: instances must be created via ``TFAutoModelForSequenceClassification.from_pretrained``."""
        raise EnvironmentError("TFAutoModelForSequenceClassification is designed to be instantiated "
            "using the `TFAutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)` method.")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the sequence classification model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `distilbert`: TFDistilBertForSequenceClassification (DistilBERT model)
            - contains `roberta`: TFRobertaForSequenceClassification (RoBERTa model)
            - contains `bert`: TFBertForSequenceClassification (Bert model)
            - contains `xlnet`: TFXLNetForSequenceClassification (XLNet model)
            - contains `xlm`: TFXLMForSequenceClassification (XLM model)

        Params:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing model weights saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `PyTorch, TF 1.X or TF 2.0 checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In the case of a PyTorch checkpoint, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument.

            from_pt: (`Optional`) Boolean
                Set to True if the checkpoint is a PyTorch checkpoint.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~pytorch_transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and :func:`~pytorch_transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attentions=True``). Behaves differently depending on whether a `config` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~pytorch_transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

        Raises:
            ValueError: if ``pretrained_model_name_or_path`` contains none of the patterns listed above.

        Examples::

            model = TFAutoModelForSequenceClassification.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = TFAutoModelForSequenceClassification.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/bert_model/')`
            model = TFAutoModelForSequenceClassification.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a PyTorch checkpoint file instead of a TF 2.0 model (slower)
            config = AutoConfig.from_json_file('./pt_model/bert_pt_model_config.json')
            model = TFAutoModelForSequenceClassification.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)

        """
        # Order matters: 'distilbert' and 'roberta' both contain the substring
        # 'bert', so they must be tested before the generic 'bert' pattern.
        if 'distilbert' in pretrained_model_name_or_path:
            return TFDistilBertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'roberta' in pretrained_model_name_or_path:
            return TFRobertaForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'bert' in pretrained_model_name_or_path:
            return TFBertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'xlnet' in pretrained_model_name_or_path:
            return TFXLNetForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'xlm' in pretrained_model_name_or_path:
            return TFXLMForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        raise ValueError("Unrecognized model identifier in {}. Should contains one of "
                         "'bert', 'xlnet', 'xlm', 'roberta'".format(pretrained_model_name_or_path))


class TFAutoModelForQuestionAnswering(object):
    r"""
        :class:`~pytorch_transformers.TFAutoModelForQuestionAnswering` is a generic model class
        that will be instantiated as one of the question answering model classes of the library
        when created with the `TFAutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)`
        class method.

        The `from_pretrained()` method takes care of returning the correct model class instance
        using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `distilbert`: TFDistilBertForQuestionAnswering (DistilBERT model)
            - contains `bert`: TFBertForQuestionAnswering (Bert model)
            - contains `xlnet`: TFXLNetForQuestionAnswering (XLNet model)
            - contains `xlm`: TFXLMForQuestionAnsweringSimple (XLM model)

        This class cannot be instantiated using `__init__()` (throws an error).
    """
    def __init__(self):
        raise EnvironmentError("TFAutoModelForQuestionAnswering is designed to be instantiated "
            "using the `TFAutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)` method.")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the question answering model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `distilbert`: TFDistilBertForQuestionAnswering (DistilBERT model)
            - contains `bert`: TFBertForQuestionAnswering (Bert model)
            - contains `xlnet`: TFXLNetForQuestionAnswering (XLNet model)
            - contains `xlm`: TFXLMForQuestionAnsweringSimple (XLM model)

        Params:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing model weights saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `PyTorch, TF 1.X or TF 2.0 checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In the case of a PyTorch checkpoint, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument.

            from_pt: (`Optional`) Boolean
                Set to True if the Checkpoint is a PyTorch checkpoint.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~pytorch_transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and :func:`~pytorch_transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~pytorch_transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

        Raises:
            ValueError: if ``pretrained_model_name_or_path`` contains none of the recognized patterns.

        Examples::

            model = TFAutoModelForQuestionAnswering.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = TFAutoModelForQuestionAnswering.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = TFAutoModelForQuestionAnswering.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
            assert model.config.output_attention == True
            # Loading from a PyTorch checkpoint file instead of a TF 2.0 model (slower)
            config = AutoConfig.from_json_file('./pt_model/bert_pt_model_config.json')
            model = TFAutoModelForQuestionAnswering.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)

        """
        # Order matters: 'distilbert' contains 'bert', so it must be checked first.
        if 'distilbert' in pretrained_model_name_or_path:
            return TFDistilBertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'bert' in pretrained_model_name_or_path:
            return TFBertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'xlnet' in pretrained_model_name_or_path:
            return TFXLNetForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'xlm' in pretrained_model_name_or_path:
            # Only `TFXLMForQuestionAnsweringSimple` is imported from `modeling_tf_xlm`
            # at the top of this module; `TFXLMForQuestionAnswering` is not, and
            # referencing it here raised a NameError for every XLM identifier.
            return TFXLMForQuestionAnsweringSimple.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        raise ValueError("Unrecognized model identifier in {}. Should contains one of "
                         "'bert', 'xlnet', 'xlm'".format(pretrained_model_name_or_path))