# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class. """

from __future__ import absolute_import, division, print_function, unicode_literals

import logging

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn.parameter import Parameter

from .modeling_bert import BertConfig, BertModel
from .modeling_openai import OpenAIGPTConfig, OpenAIGPTModel
from .modeling_gpt2 import GPT2Config, GPT2Model
from .modeling_transfo_xl import TransfoXLConfig, TransfoXLModel
from .modeling_xlnet import XLNetConfig, XLNetModel
from .modeling_xlm import XLMConfig, XLMModel
from .modeling_roberta import RobertaConfig, RobertaModel
from .modeling_dilbert import DilBertConfig, DilBertModel

from .modeling_utils import PreTrainedModel, SequenceSummary

logger = logging.getLogger(__name__)

class AutoConfig(object):
    r""":class:`~pytorch_transformers.AutoConfig` is a generic configuration class
        that will be instantiated as one of the configuration classes of the library
        when created with the `AutoConfig.from_pretrained(pretrained_model_name_or_path)`
        class method.

        The `from_pretrained()` method takes care of returning the correct configuration
        class instance using pattern matching on the `pretrained_model_name_or_path` string.

        The configuration class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `dilbert`: DilBertConfig (DilBERT model)
            - contains `roberta`: RobertaConfig (RoBERTa model)
            - contains `bert`: BertConfig (Bert model)
            - contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
            - contains `gpt2`: GPT2Config (OpenAI GPT-2 model)
            - contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
            - contains `xlnet`: XLNetConfig (XLNet model)
            - contains `xlm`: XLMConfig (XLM model)

        The order matters: `bert` is a substring of both `dilbert` and `roberta`,
        so those two patterns are tested before `bert`.

        This class cannot be instantiated using `__init__()` (throws an error).
    """
    def __init__(self):
        raise EnvironmentError("AutoConfig is designed to be instantiated "
            "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method.")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r""" Instantiate one of the configuration classes of the library
        from a pre-trained model configuration.

        The configuration class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `dilbert`: DilBertConfig (DilBERT model)
            - contains `roberta`: RobertaConfig (RoBERTa model)
            - contains `bert`: BertConfig (Bert model)
            - contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
            - contains `gpt2`: GPT2Config (OpenAI GPT-2 model)
            - contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
            - contains `xlnet`: XLNetConfig (XLNet model)
            - contains `xlm`: XLMConfig (XLM model)

        Params:
            **pretrained_model_name_or_path**: either:
                - a string with the `shortcut name` of a pre-trained model configuration to load from cache
                    or download and cache if not already stored in cache (e.g. 'bert-base-uncased').
                - a path to a `directory` containing a configuration file saved
                    using the `save_pretrained(save_directory)` method.
                - a path or url to a saved configuration `file`.
            **cache_dir**: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.
            **return_unused_kwargs**: (`optional`) bool:
                - If False, then this function returns just the final configuration object.
                - If True, then this function returns a tuple `(config, unused_kwargs)` where `unused_kwargs`
                is a dictionary consisting of the key/value pairs whose keys are not configuration attributes:
                ie the part of kwargs which has not been used to update `config` and is otherwise ignored.
            **kwargs**: (`optional`) dict:
                Dictionary of key/value pairs with which to update the configuration object after loading.
                - The values in kwargs of any keys which are configuration attributes will be used
                to override the loaded values.
                - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
                by the `return_unused_kwargs` keyword parameter.

        Returns:
            Whatever the selected configuration class' own `from_pretrained()` returns;
            all keyword arguments are forwarded to it unchanged.

        Raises:
            ValueError: if `pretrained_model_name_or_path` contains none of the known patterns.

        Examples::

            config = AutoConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            config = AutoConfig.from_pretrained('./test/bert_saved_model/')  # E.g. config (or model) was saved using `save_pretrained('./test/bert_saved_model/')`
            config = AutoConfig.from_pretrained('./test/bert_saved_model/my_configuration.json')
            config = AutoConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
            assert config.output_attention == True
            config, unused_kwargs = AutoConfig.from_pretrained('bert-base-uncased', output_attention=True,
                                                               foo=False, return_unused_kwargs=True)
            assert config.output_attention == True
            assert unused_kwargs == {'foo': False}

        """
        # 'dilbert' and 'roberta' both contain 'bert', so they must be tested
        # before the plain 'bert' pattern or they would resolve to BertConfig.
        if 'dilbert' in pretrained_model_name_or_path:
            return DilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif 'roberta' in pretrained_model_name_or_path:
            return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif 'bert' in pretrained_model_name_or_path:
            return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif 'openai-gpt' in pretrained_model_name_or_path:
            return OpenAIGPTConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif 'gpt2' in pretrained_model_name_or_path:
            return GPT2Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif 'transfo-xl' in pretrained_model_name_or_path:
            return TransfoXLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif 'xlnet' in pretrained_model_name_or_path:
            return XLNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif 'xlm' in pretrained_model_name_or_path:
            return XLMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # The listed patterns mirror the branches above, in the order they are tested.
        raise ValueError("Unrecognized model identifier in {}. Should contain one of "
                         "'dilbert', 'roberta', 'bert', 'openai-gpt', 'gpt2', "
                         "'transfo-xl', 'xlnet', 'xlm'".format(pretrained_model_name_or_path))


class AutoModel(object):
    r"""
        :class:`~pytorch_transformers.AutoModel` is a generic model class
        that will be instantiated as one of the base model classes of the library
        when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
        class method.

        The `from_pretrained()` method takes care of returning the correct model class instance
        using pattern matching on the `pretrained_model_name_or_path` string.

        The base model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `dilbert`: DilBertModel (DilBERT model)
            - contains `roberta`: RobertaModel (RoBERTa model)
            - contains `bert`: BertModel (Bert model)
            - contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
            - contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
            - contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
            - contains `xlnet`: XLNetModel (XLNet model)
            - contains `xlm`: XLMModel (XLM model)

        The order matters: `bert` is a substring of both `dilbert` and `roberta`,
        so those two patterns are tested before `bert`.

        This class cannot be instantiated using `__init__()` (throws an error).
    """
    def __init__(self):
        raise EnvironmentError("AutoModel is designed to be instantiated "
            "using the `AutoModel.from_pretrained(pretrained_model_name_or_path)` method.")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiate one of the base model classes of the library
        from a pre-trained model configuration.

        The base model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `dilbert`: DilBertModel (DilBERT model)
            - contains `roberta`: RobertaModel (RoBERTa model)
            - contains `bert`: BertModel (Bert model)
            - contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
            - contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
            - contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
            - contains `xlnet`: XLNetModel (XLNet model)
            - contains `xlm`: XLMModel (XLM model)

            The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
            To train the model, you should first set it back in training mode with `model.train()`

        Params:
            **pretrained_model_name_or_path**: either:
                - a string with the `shortcut name` of a pre-trained model to load from cache
                    or download and cache if not already stored in cache (e.g. 'bert-base-uncased').
                - a path to a `directory` containing a configuration file saved
                    using the `save_pretrained(save_directory)` method.
                - a path or url to a tensorflow index checkpoint `file` (e.g. `./tf_model/model.ckpt.index`).
                    In this case, ``from_tf`` should be set to True and a configuration object should be
                    provided as `config` argument. This loading option is slower than converting the TensorFlow
                    checkpoint in a PyTorch model using the provided conversion scripts and loading
                    the PyTorch model afterwards.
            **model_args**: (`optional`) Sequence:
                All remaining positional arguments will be passed to the underlying model's __init__ function
            **config**: an optional configuration for the model to use instead of an automatically loaded configuration.
                Configuration can be automatically loaded when:
                - the model is a model provided by the library (loaded with a `shortcut name` of a pre-trained model), or
                - the model was saved using the `save_pretrained(save_directory)` (loaded by supplying the save directory).
            **state_dict**: an optional state dictionary for the model to use instead of a state dictionary loaded
                from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using `save_pretrained(dir)` and `from_pretrained(save_directory)` is not
                a simpler option.
            **cache_dir**: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.
            **output_loading_info**: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
            **kwargs**: (`optional`) dict:
                Dictionary of key, values to update the configuration object after loading.
                Can be used to override selected configuration parameters. E.g. ``output_attention=True``.

               - If a configuration is provided with `config`, **kwargs will be directly passed
                 to the underlying model's __init__ method.
               - If a configuration is not provided, **kwargs will be first passed to the pretrained
                 model configuration class loading function (`PretrainedConfig.from_pretrained`).
                 Each key of **kwargs that corresponds to a configuration attribute
                 will be used to override said attribute with the supplied **kwargs value.
                 Remaining keys that do not correspond to any configuration attribute will
                 be passed to the underlying model's __init__ function.

        Returns:
            Whatever the selected model class' own `from_pretrained()` returns;
            all positional and keyword arguments are forwarded to it unchanged.

        Raises:
            ValueError: if `pretrained_model_name_or_path` contains none of the known patterns.

        Examples::

            model = AutoModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModel.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/bert_model/')`
            model = AutoModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
            assert model.config.output_attention == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json')
            model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        # 'dilbert' and 'roberta' both contain 'bert', so they must be tested
        # before the plain 'bert' pattern or they would resolve to BertModel.
        if 'dilbert' in pretrained_model_name_or_path:
            return DilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'roberta' in pretrained_model_name_or_path:
            return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'bert' in pretrained_model_name_or_path:
            return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'openai-gpt' in pretrained_model_name_or_path:
            return OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'gpt2' in pretrained_model_name_or_path:
            return GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'transfo-xl' in pretrained_model_name_or_path:
            return TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'xlnet' in pretrained_model_name_or_path:
            return XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif 'xlm' in pretrained_model_name_or_path:
            return XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        # The listed patterns mirror the branches above, in the order they are tested.
        raise ValueError("Unrecognized model identifier in {}. Should contain one of "
                         "'dilbert', 'roberta', 'bert', 'openai-gpt', 'gpt2', "
                         "'transfo-xl', 'xlnet', 'xlm'".format(pretrained_model_name_or_path))