test_modeling_tf_auto.py 11.9 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
# coding=utf-8
Sylvain Gugger's avatar
Sylvain Gugger committed
2
# Copyright 2020 The HuggingFace Team. All rights reserved.
thomwolf's avatar
thomwolf committed
3
4
5
6
7
8
9
10
11
12
13
14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Aymeric Augustin's avatar
Aymeric Augustin committed
15

16
17
import copy
import tempfile
Aymeric Augustin's avatar
Aymeric Augustin committed
18
import unittest
thomwolf's avatar
thomwolf committed
19

Kamal Raj's avatar
Kamal Raj committed
20
21
22
23
24
25
26
27
from transformers import CONFIG_MAPPING, AutoConfig, BertConfig, GPT2Config, T5Config, TapasConfig, is_tf_available
from transformers.testing_utils import (
    DUMMY_UNKNOWN_IDENTIFIER,
    SMALL_MODEL_IDENTIFIER,
    require_tensorflow_probability,
    require_tf,
    slow,
)
Aymeric Augustin's avatar
Aymeric Augustin committed
28

29
from ..bert.test_modeling_bert import BertModelTester
30

31

32
if is_tf_available():
33
34
    from transformers import (
        TFAutoModel,
35
36
        TFAutoModelForCausalLM,
        TFAutoModelForMaskedLM,
thomwolf's avatar
thomwolf committed
37
        TFAutoModelForPreTraining,
38
39
40
        TFAutoModelForQuestionAnswering,
        TFAutoModelForSeq2SeqLM,
        TFAutoModelForSequenceClassification,
Kamal Raj's avatar
Kamal Raj committed
41
        TFAutoModelForTableQuestionAnswering,
42
        TFAutoModelForTokenClassification,
43
44
        TFAutoModelWithLMHead,
        TFBertForMaskedLM,
45
        TFBertForPreTraining,
46
        TFBertForQuestionAnswering,
47
48
        TFBertForSequenceClassification,
        TFBertModel,
49
50
        TFFunnelBaseModel,
        TFFunnelModel,
51
        TFGPT2LMHeadModel,
52
        TFRobertaForMaskedLM,
53
        TFT5ForConditionalGeneration,
Kamal Raj's avatar
Kamal Raj committed
54
        TFTapasForQuestionAnswering,
55
    )
Sylvain Gugger's avatar
Sylvain Gugger committed
56
    from transformers.models.auto.modeling_tf_auto import (
57
58
        TF_MODEL_FOR_CAUSAL_LM_MAPPING,
        TF_MODEL_FOR_MASKED_LM_MAPPING,
59
60
61
62
        TF_MODEL_FOR_PRETRAINING_MAPPING,
        TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
63
        TF_MODEL_MAPPING,
64
    )
Sylvain Gugger's avatar
Sylvain Gugger committed
65
66
67
    from transformers.models.bert.modeling_tf_bert import TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST
    from transformers.models.gpt2.modeling_tf_gpt2 import TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST
    from transformers.models.t5.modeling_tf_t5 import TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST
Kamal Raj's avatar
Kamal Raj committed
68
    from transformers.models.tapas.modeling_tf_tapas import TF_TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST
thomwolf's avatar
thomwolf committed
69
70


71
72
73
74
75
76
77
78
79
80
# Custom configuration used by `test_new_model_registration`: a `BertConfig`
# subclass whose only change is the `model_type` identifier.
class NewModelConfig(BertConfig):
    model_type = "new-model"


# `TFBertModel` is only imported when TensorFlow is available, so the subclass
# is defined under the same guard.
if is_tf_available():

    # TF Bert model bound to `NewModelConfig`; registered with the auto classes
    # in `test_new_model_registration`.
    class TFNewModel(TFBertModel):
        config_class = NewModelConfig


81
@require_tf
thomwolf's avatar
thomwolf committed
82
class TFAutoModelTest(unittest.TestCase):
83
    @slow
    def test_model_from_pretrained(self):
        # Resolve "bert-base-cased" through the auto API and check that both the
        # config and the base model come back as their Bert classes.
        model_name = "bert-base-cased"
        config = AutoConfig.from_pretrained(model_name)
        self.assertIsNotNone(config)
        self.assertIsInstance(config, BertConfig)

        model = TFAutoModel.from_pretrained(model_name)
        self.assertIsNotNone(model)
        self.assertIsInstance(model, TFBertModel)
thomwolf's avatar
thomwolf committed
93

thomwolf's avatar
thomwolf committed
94
95
    @slow
    def test_model_for_pretraining_from_pretrained(self):
        # Resolve "bert-base-cased" through `TFAutoModelForPreTraining` and check
        # that the pretraining head class is selected.
        model_name = "bert-base-cased"
        config = AutoConfig.from_pretrained(model_name)
        self.assertIsNotNone(config)
        self.assertIsInstance(config, BertConfig)

        model = TFAutoModelForPreTraining.from_pretrained(model_name)
        self.assertIsNotNone(model)
        self.assertIsInstance(model, TFBertForPreTraining)
thomwolf's avatar
thomwolf committed
104

105
106
107
108
109
110
111
112
113
114
115
116
    @slow
    def test_model_for_causal_lm(self):
        # Uses only the first entry of the GPT-2 archive list and checks that the
        # causal-LM auto class resolves to the GPT-2 config and LM head model.
        for checkpoint in TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            gpt2_config = AutoConfig.from_pretrained(checkpoint)
            self.assertIsNotNone(gpt2_config)
            self.assertIsInstance(gpt2_config, GPT2Config)

            # Exercise the plain call first, then the variant that also returns
            # the loading info; only the second result is asserted on.
            lm_model = TFAutoModelForCausalLM.from_pretrained(checkpoint)
            lm_model, _loading_info = TFAutoModelForCausalLM.from_pretrained(
                checkpoint, output_loading_info=True
            )
            self.assertIsNotNone(lm_model)
            self.assertIsInstance(lm_model, TFGPT2LMHeadModel)

117
    @slow
    def test_lmhead_model_from_pretrained(self):
        # Uses only the first entry of the Bert archive list and checks that
        # `TFAutoModelWithLMHead` resolves it to the Bert masked-LM model.
        for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

            model = TFAutoModelWithLMHead.from_pretrained(model_name)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, TFBertForMaskedLM)

128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
    @slow
    def test_model_for_masked_lm(self):
        # Uses only the first entry of the Bert archive list and checks that the
        # masked-LM auto class resolves to the Bert config and masked-LM model.
        for checkpoint in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            bert_config = AutoConfig.from_pretrained(checkpoint)
            self.assertIsNotNone(bert_config)
            self.assertIsInstance(bert_config, BertConfig)

            # Exercise the plain call first, then the variant that also returns
            # the loading info; only the second result is asserted on.
            mlm_model = TFAutoModelForMaskedLM.from_pretrained(checkpoint)
            mlm_model, _loading_info = TFAutoModelForMaskedLM.from_pretrained(
                checkpoint, output_loading_info=True
            )
            self.assertIsNotNone(mlm_model)
            self.assertIsInstance(mlm_model, TFBertForMaskedLM)

    @slow
    def test_model_for_encoder_decoder_lm(self):
        # Uses only the first entry of the T5 archive list and checks that the
        # seq2seq auto class resolves to the T5 config and conditional-generation model.
        for checkpoint in TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            t5_config = AutoConfig.from_pretrained(checkpoint)
            self.assertIsNotNone(t5_config)
            self.assertIsInstance(t5_config, T5Config)

            # Exercise the plain call first, then the variant that also returns
            # the loading info; only the second result is asserted on.
            seq2seq_model = TFAutoModelForSeq2SeqLM.from_pretrained(checkpoint)
            seq2seq_model, _loading_info = TFAutoModelForSeq2SeqLM.from_pretrained(
                checkpoint, output_loading_info=True
            )
            self.assertIsNotNone(seq2seq_model)
            self.assertIsInstance(seq2seq_model, TFT5ForConditionalGeneration)

152
    @slow
    def test_sequence_classification_model_from_pretrained(self):
        # Checks that the sequence-classification auto class resolves
        # "bert-base-uncased" to the Bert config and classification model.
        # for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
        for model_name in ["bert-base-uncased"]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

            model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, TFBertForSequenceClassification)

164
    @slow
    def test_question_answering_model_from_pretrained(self):
        # Checks that the question-answering auto class resolves
        # "bert-base-uncased" to the Bert config and QA model.
        # for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
        for model_name in ["bert-base-uncased"]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

            model = TFAutoModelForQuestionAnswering.from_pretrained(model_name)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, TFBertForQuestionAnswering)

Kamal Raj's avatar
Kamal Raj committed
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
    @slow
    @require_tensorflow_probability
    def test_table_question_answering_model_from_pretrained(self):
        # Uses the single archive entry selected by the [5:6] slice and checks that
        # the table-QA auto class resolves to the Tapas config and QA model.
        for checkpoint in TF_TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST[5:6]:
            tapas_config = AutoConfig.from_pretrained(checkpoint)
            self.assertIsNotNone(tapas_config)
            self.assertIsInstance(tapas_config, TapasConfig)

            # Exercise the plain call first, then the variant that also returns
            # the loading info; only the second result is asserted on.
            table_qa_model = TFAutoModelForTableQuestionAnswering.from_pretrained(checkpoint)
            table_qa_model, _loading_info = TFAutoModelForTableQuestionAnswering.from_pretrained(
                checkpoint, output_loading_info=True
            )
            self.assertIsNotNone(table_qa_model)
            self.assertIsInstance(table_qa_model, TFTapasForQuestionAnswering)

Julien Chaumond's avatar
Julien Chaumond committed
191
    def test_from_pretrained_identifier(self):
        # Loads SMALL_MODEL_IDENTIFIER through `TFAutoModelWithLMHead` and checks
        # the resolved class plus the total and trainable parameter counts.
        model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
        self.assertIsInstance(model, TFBertForMaskedLM)
        self.assertEqual(model.num_parameters(), 14410)
        self.assertEqual(model.num_parameters(only_trainable=True), 14410)
Julien Chaumond's avatar
Julien Chaumond committed
196
197

    def test_from_identifier_from_model_type(self):
        # Loads DUMMY_UNKNOWN_IDENTIFIER through `TFAutoModelWithLMHead` and checks
        # that it resolves to the Roberta masked-LM class with the expected
        # total and trainable parameter counts.
        model = TFAutoModelWithLMHead.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER)
        self.assertIsInstance(model, TFRobertaForMaskedLM)
        self.assertEqual(model.num_parameters(), 14410)
        self.assertEqual(model.num_parameters(only_trainable=True), 14410)
202

203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
    def test_from_pretrained_with_tuple_values(self):
        # For the auto model mapping, FunnelConfig has two models: FunnelModel and FunnelBaseModel
        funnel_model = TFAutoModel.from_pretrained("sgugger/funnel-random-tiny")
        self.assertIsInstance(funnel_model, TFFunnelModel)

        # Work on a copy of the config and point `architectures` at the base
        # variant before building from it.
        base_config = copy.deepcopy(funnel_model.config)
        base_config.architectures = ["FunnelBaseModel"]
        base_model = TFAutoModel.from_config(base_config)
        self.assertIsInstance(base_model, TFFunnelBaseModel)

        # Save and reload from a temporary directory; the reloaded model is
        # expected to be the base variant as well.
        with tempfile.TemporaryDirectory() as save_dir:
            base_model.save_pretrained(save_dir)
            reloaded_model = TFAutoModel.from_pretrained(save_dir)
            self.assertIsInstance(reloaded_model, TFFunnelBaseModel)

218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
    def test_new_model_registration(self):
        # Registers NewModelConfig / TFNewModel with each auto class, checks the
        # two rejection paths and a from_config -> save -> from_pretrained round
        # trip, and removes every registration again in the `finally` clause.
        try:
            AutoConfig.register("new-model", NewModelConfig)

            classes_under_test = (
                TFAutoModel,
                TFAutoModelForCausalLM,
                TFAutoModelForMaskedLM,
                TFAutoModelForPreTraining,
                TFAutoModelForQuestionAnswering,
                TFAutoModelForSequenceClassification,
                TFAutoModelForTokenClassification,
            )

            for auto_cls in classes_under_test:
                with self.subTest(auto_cls.__name__):
                    # Wrong config class will raise an error
                    with self.assertRaises(ValueError):
                        auto_cls.register(BertConfig, TFNewModel)

                    auto_cls.register(NewModelConfig, TFNewModel)

                    # Trying to register something existing in the Transformers library will raise an error
                    with self.assertRaises(ValueError):
                        auto_cls.register(BertConfig, TFBertModel)

                    # Now that the config is registered, it can be used as any other config with the auto-API
                    small_bert_config = BertModelTester(self).get_config()
                    new_config = NewModelConfig(**small_bert_config.to_dict())
                    built_model = auto_cls.from_config(new_config)
                    self.assertIsInstance(built_model, TFNewModel)

                    with tempfile.TemporaryDirectory() as save_dir:
                        built_model.save_pretrained(save_dir)
                        reloaded_model = auto_cls.from_pretrained(save_dir)
                        self.assertIsInstance(reloaded_model, TFNewModel)

        finally:
            # Drop the extra config entry, if it was added.
            if "new-model" in CONFIG_MAPPING._extra_content:
                del CONFIG_MAPPING._extra_content["new-model"]

            # Drop the extra model entry from each mapping that received one.
            touched_mappings = (
                TF_MODEL_MAPPING,
                TF_MODEL_FOR_PRETRAINING_MAPPING,
                TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
                TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
                TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
                TF_MODEL_FOR_CAUSAL_LM_MAPPING,
                TF_MODEL_FOR_MASKED_LM_MAPPING,
            )
            for model_mapping in touched_mappings:
                if NewModelConfig in model_mapping._extra_content:
                    del model_mapping._extra_content[NewModelConfig]
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289

    def test_repo_not_found(self):
        # Loading the identifier "bert-base" must raise an EnvironmentError whose
        # message matches the pattern below.
        expected_message = "bert-base is not a local folder and is not a valid model identifier"
        with self.assertRaisesRegex(EnvironmentError, expected_message):
            TFAutoModel.from_pretrained("bert-base")

    def test_revision_not_found(self):
        # Requesting revision "aaaaaa" must raise an EnvironmentError whose
        # message matches the pattern below (parentheses are regex-escaped).
        expected_message = r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
        with self.assertRaisesRegex(EnvironmentError, expected_message):
            TFAutoModel.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")

    def test_model_file_not_found(self):
        # Loading "hf-internal-testing/config-no-model" must raise an
        # EnvironmentError that names the missing tf_model.h5 file.
        repo_id = "hf-internal-testing/config-no-model"
        expected_message = "hf-internal-testing/config-no-model does not appear to have a file named tf_model.h5"
        with self.assertRaisesRegex(EnvironmentError, expected_message):
            TFAutoModel.from_pretrained(repo_id)

    def test_model_from_pt_suggestion(self):
        # Loading "hf-internal-testing/tiny-bert-pt-only" must raise an
        # EnvironmentError whose message suggests `from_pt=True`.
        expected_message = "Use `from_pt=True` to load this model"
        with self.assertRaisesRegex(EnvironmentError, expected_message):
            TFAutoModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")