test_modeling_tf_auto.py 12.8 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
# coding=utf-8
Sylvain Gugger's avatar
Sylvain Gugger committed
2
# Copyright 2020 The HuggingFace Team. All rights reserved.
thomwolf's avatar
thomwolf committed
3
4
5
6
7
8
9
10
11
12
13
14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Aymeric Augustin's avatar
Aymeric Augustin committed
15

16
17
import copy
import tempfile
Aymeric Augustin's avatar
Aymeric Augustin committed
18
import unittest
thomwolf's avatar
thomwolf committed
19

Kamal Raj's avatar
Kamal Raj committed
20
21
22
23
from transformers import CONFIG_MAPPING, AutoConfig, BertConfig, GPT2Config, T5Config, TapasConfig, is_tf_available
from transformers.testing_utils import (
    DUMMY_UNKNOWN_IDENTIFIER,
    SMALL_MODEL_IDENTIFIER,
24
    RequestCounter,
Kamal Raj's avatar
Kamal Raj committed
25
26
27
28
    require_tensorflow_probability,
    require_tf,
    slow,
)
Aymeric Augustin's avatar
Aymeric Augustin committed
29

30
from ..bert.test_modeling_bert import BertModelTester
31

32

33
if is_tf_available():
34
35
    from transformers import (
        TFAutoModel,
36
37
        TFAutoModelForCausalLM,
        TFAutoModelForMaskedLM,
thomwolf's avatar
thomwolf committed
38
        TFAutoModelForPreTraining,
39
40
41
        TFAutoModelForQuestionAnswering,
        TFAutoModelForSeq2SeqLM,
        TFAutoModelForSequenceClassification,
Kamal Raj's avatar
Kamal Raj committed
42
        TFAutoModelForTableQuestionAnswering,
43
        TFAutoModelForTokenClassification,
44
45
        TFAutoModelWithLMHead,
        TFBertForMaskedLM,
46
        TFBertForPreTraining,
47
        TFBertForQuestionAnswering,
48
49
        TFBertForSequenceClassification,
        TFBertModel,
50
51
        TFFunnelBaseModel,
        TFFunnelModel,
52
        TFGPT2LMHeadModel,
53
        TFRobertaForMaskedLM,
54
        TFT5ForConditionalGeneration,
Kamal Raj's avatar
Kamal Raj committed
55
        TFTapasForQuestionAnswering,
56
    )
Sylvain Gugger's avatar
Sylvain Gugger committed
57
    from transformers.models.auto.modeling_tf_auto import (
58
59
        TF_MODEL_FOR_CAUSAL_LM_MAPPING,
        TF_MODEL_FOR_MASKED_LM_MAPPING,
60
61
62
63
        TF_MODEL_FOR_PRETRAINING_MAPPING,
        TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
64
        TF_MODEL_MAPPING,
65
    )
Sylvain Gugger's avatar
Sylvain Gugger committed
66
67
68
    from transformers.models.bert.modeling_tf_bert import TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST
    from transformers.models.gpt2.modeling_tf_gpt2 import TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST
    from transformers.models.t5.modeling_tf_t5 import TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST
Kamal Raj's avatar
Kamal Raj committed
69
    from transformers.models.tapas.modeling_tf_tapas import TF_TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST
thomwolf's avatar
thomwolf committed
70
71


72
73
74
75
76
77
78
79
80
81
class NewModelConfig(BertConfig):
    # BertConfig subclass published under its own model type; the registration test below
    # registers it with AutoConfig under this same "new-model" key.
    model_type = "new-model"


# TFBertModel is only imported when TensorFlow is available, so the subclass is guarded the same way.
if is_tf_available():

    class TFNewModel(TFBertModel):
        # Ties this model class to NewModelConfig; the registration test relies on this pairing.
        config_class = NewModelConfig


82
@require_tf
thomwolf's avatar
thomwolf committed
83
class TFAutoModelTest(unittest.TestCase):
84
    @slow
thomwolf's avatar
thomwolf committed
85
    def test_model_from_pretrained(self):
        """Check that "bert-base-cased" resolves to `BertConfig` and `TFBertModel` via the auto classes."""
        checkpoint = "bert-base-cased"

        bert_config = AutoConfig.from_pretrained(checkpoint)
        self.assertIsNotNone(bert_config)
        self.assertIsInstance(bert_config, BertConfig)

        bert_model = TFAutoModel.from_pretrained(checkpoint)
        self.assertIsNotNone(bert_model)
        self.assertIsInstance(bert_model, TFBertModel)
thomwolf's avatar
thomwolf committed
94

thomwolf's avatar
thomwolf committed
95
96
    @slow
    def test_model_for_pretraining_from_pretrained(self):
        """Check that "bert-base-cased" resolves to `TFBertForPreTraining` via `TFAutoModelForPreTraining`."""
        checkpoint = "bert-base-cased"

        bert_config = AutoConfig.from_pretrained(checkpoint)
        self.assertIsNotNone(bert_config)
        self.assertIsInstance(bert_config, BertConfig)

        pretraining_model = TFAutoModelForPreTraining.from_pretrained(checkpoint)
        self.assertIsNotNone(pretraining_model)
        self.assertIsInstance(pretraining_model, TFBertForPreTraining)
thomwolf's avatar
thomwolf committed
105

106
107
108
109
110
111
112
113
114
115
116
117
    @slow
    def test_model_for_causal_lm(self):
        """Resolve the first listed GPT-2 checkpoint via `TFAutoModelForCausalLM`.

        The checkpoint is loaded twice, once plainly and once with `output_loading_info=True`,
        and the result of each load is asserted on separately.
        """
        for model_name in TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, GPT2Config)

            # Plain call: verify its result on its own before the name is rebound below.
            model = TFAutoModelForCausalLM.from_pretrained(model_name)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, TFGPT2LMHeadModel)

            # Same checkpoint, this time also requesting the loading report.
            model, loading_info = TFAutoModelForCausalLM.from_pretrained(model_name, output_loading_info=True)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, TFGPT2LMHeadModel)
            self.assertIsNotNone(loading_info)

118
    @slow
thomwolf's avatar
thomwolf committed
119
    def test_lmhead_model_from_pretrained(self):
        """Check that the first listed BERT checkpoint resolves to `TFBertForMaskedLM` via `TFAutoModelWithLMHead`."""
        for checkpoint in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            bert_config = AutoConfig.from_pretrained(checkpoint)
            self.assertIsNotNone(bert_config)
            self.assertIsInstance(bert_config, BertConfig)

            lm_model = TFAutoModelWithLMHead.from_pretrained(checkpoint)
            self.assertIsNotNone(lm_model)
            self.assertIsInstance(lm_model, TFBertForMaskedLM)

129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
    @slow
    def test_model_for_masked_lm(self):
        """Resolve the first listed BERT checkpoint via `TFAutoModelForMaskedLM`.

        The checkpoint is loaded twice, once plainly and once with `output_loading_info=True`,
        and the result of each load is asserted on separately.
        """
        for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

            # Plain call: verify its result on its own before the name is rebound below.
            model = TFAutoModelForMaskedLM.from_pretrained(model_name)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, TFBertForMaskedLM)

            # Same checkpoint, this time also requesting the loading report.
            model, loading_info = TFAutoModelForMaskedLM.from_pretrained(model_name, output_loading_info=True)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, TFBertForMaskedLM)
            self.assertIsNotNone(loading_info)

    @slow
    def test_model_for_encoder_decoder_lm(self):
        """Resolve the first listed T5 checkpoint via `TFAutoModelForSeq2SeqLM`.

        The checkpoint is loaded twice, once plainly and once with `output_loading_info=True`,
        and the result of each load is asserted on separately.
        """
        for model_name in TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, T5Config)

            # Plain call: verify its result on its own before the name is rebound below.
            model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, TFT5ForConditionalGeneration)

            # Same checkpoint, this time also requesting the loading report.
            model, loading_info = TFAutoModelForSeq2SeqLM.from_pretrained(model_name, output_loading_info=True)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, TFT5ForConditionalGeneration)
            self.assertIsNotNone(loading_info)

153
    @slow
thomwolf's avatar
thomwolf committed
154
    def test_sequence_classification_model_from_pretrained(self):
        """Check that "bert-base-uncased" resolves to `TFBertForSequenceClassification` via the auto class."""
        # The checkpoint name is pinned here rather than taken from TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1].
        model_name = "bert-base-uncased"

        bert_config = AutoConfig.from_pretrained(model_name)
        self.assertIsNotNone(bert_config)
        self.assertIsInstance(bert_config, BertConfig)

        classifier = TFAutoModelForSequenceClassification.from_pretrained(model_name)
        self.assertIsNotNone(classifier)
        self.assertIsInstance(classifier, TFBertForSequenceClassification)

165
    @slow
thomwolf's avatar
thomwolf committed
166
    def test_question_answering_model_from_pretrained(self):
        """Check that "bert-base-uncased" resolves to `TFBertForQuestionAnswering` via the auto class."""
        # The checkpoint name is pinned here rather than taken from TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1].
        model_name = "bert-base-uncased"

        bert_config = AutoConfig.from_pretrained(model_name)
        self.assertIsNotNone(bert_config)
        self.assertIsInstance(bert_config, BertConfig)

        qa_model = TFAutoModelForQuestionAnswering.from_pretrained(model_name)
        self.assertIsNotNone(qa_model)
        self.assertIsInstance(qa_model, TFBertForQuestionAnswering)

Kamal Raj's avatar
Kamal Raj committed
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
    @slow
    @require_tensorflow_probability
    def test_table_question_answering_model_from_pretrained(self):
        """Resolve one listed TAPAS checkpoint (index 5) via `TFAutoModelForTableQuestionAnswering`.

        The checkpoint is loaded twice, once plainly and once with `output_loading_info=True`,
        and the result of each load is asserted on separately.
        """
        for model_name in TF_TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST[5:6]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, TapasConfig)

            # Plain call: verify its result on its own before the name is rebound below.
            model = TFAutoModelForTableQuestionAnswering.from_pretrained(model_name)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, TFTapasForQuestionAnswering)

            # Same checkpoint, this time also requesting the loading report.
            model, loading_info = TFAutoModelForTableQuestionAnswering.from_pretrained(
                model_name, output_loading_info=True
            )
            self.assertIsNotNone(model)
            self.assertIsInstance(model, TFTapasForQuestionAnswering)
            self.assertIsNotNone(loading_info)

Julien Chaumond's avatar
Julien Chaumond committed
192
    def test_from_pretrained_identifier(self):
        """Load `SMALL_MODEL_IDENTIFIER` with an LM head and check its class and parameter counts."""
        expected_param_count = 14410

        lm_model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
        self.assertIsInstance(lm_model, TFBertForMaskedLM)

        # Total and trainable-only counts are expected to be the same number.
        self.assertEqual(lm_model.num_parameters(), expected_param_count)
        self.assertEqual(lm_model.num_parameters(only_trainable=True), expected_param_count)
Julien Chaumond's avatar
Julien Chaumond committed
197
198

    def test_from_identifier_from_model_type(self):
        """Load `DUMMY_UNKNOWN_IDENTIFIER` with an LM head and check it resolves to `TFRobertaForMaskedLM`."""
        expected_param_count = 14410

        lm_model = TFAutoModelWithLMHead.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER)
        self.assertIsInstance(lm_model, TFRobertaForMaskedLM)

        # Total and trainable-only counts are expected to be the same number.
        self.assertEqual(lm_model.num_parameters(), expected_param_count)
        self.assertEqual(lm_model.num_parameters(only_trainable=True), expected_param_count)
203

204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
    def test_from_pretrained_with_tuple_values(self):
        """Check that `TFAutoModel` picks between the two Funnel classes based on `config.architectures`."""
        # For the auto model mapping, FunnelConfig has two models: FunnelModel and FunnelBaseModel
        funnel_model = TFAutoModel.from_pretrained("sgugger/funnel-random-tiny")
        self.assertIsInstance(funnel_model, TFFunnelModel)

        # Work on a copy so the loaded model's own config is left untouched.
        base_config = copy.deepcopy(funnel_model.config)
        base_config.architectures = ["FunnelBaseModel"]
        base_model = TFAutoModel.from_config(base_config)
        self.assertIsInstance(base_model, TFFunnelBaseModel)

        # The same class must come back after a save/load round trip.
        with tempfile.TemporaryDirectory() as save_dir:
            base_model.save_pretrained(save_dir)
            reloaded_model = TFAutoModel.from_pretrained(save_dir)
            self.assertIsInstance(reloaded_model, TFFunnelBaseModel)

219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
    def test_new_model_registration(self):
        """Register `NewModelConfig`/`TFNewModel` with several TF auto classes and use them through the auto API.

        For every auto class, checks that registering against the wrong config class or an existing
        library model raises `ValueError`, then that the registered pair works with `from_config`
        and survives a `save_pretrained`/`from_pretrained` round trip. Registrations are removed
        in the `finally` clause whether or not the checks pass.
        """
        try:
            AutoConfig.register("new-model", NewModelConfig)

            tf_auto_classes = (
                TFAutoModel,
                TFAutoModelForCausalLM,
                TFAutoModelForMaskedLM,
                TFAutoModelForPreTraining,
                TFAutoModelForQuestionAnswering,
                TFAutoModelForSequenceClassification,
                TFAutoModelForTokenClassification,
            )

            for tf_auto_class in tf_auto_classes:
                with self.subTest(tf_auto_class.__name__):
                    # Wrong config class will raise an error
                    with self.assertRaises(ValueError):
                        tf_auto_class.register(BertConfig, TFNewModel)
                    tf_auto_class.register(NewModelConfig, TFNewModel)
                    # Trying to register something existing in the Transformers library will raise an error
                    with self.assertRaises(ValueError):
                        tf_auto_class.register(BertConfig, TFBertModel)

                    # Now that the config is registered, it can be used as any other config with the auto-API
                    bert_tiny_config = BertModelTester(self).get_config()
                    new_model_config = NewModelConfig(**bert_tiny_config.to_dict())
                    built_model = tf_auto_class.from_config(new_model_config)
                    self.assertIsInstance(built_model, TFNewModel)

                    with tempfile.TemporaryDirectory() as save_dir:
                        built_model.save_pretrained(save_dir)
                        reloaded_model = tf_auto_class.from_pretrained(save_dir)
                        self.assertIsInstance(reloaded_model, TFNewModel)

        finally:
            # Undo the registrations made above so they do not persist after this test.
            if "new-model" in CONFIG_MAPPING._extra_content:
                del CONFIG_MAPPING._extra_content["new-model"]
            registered_mappings = (
                TF_MODEL_MAPPING,
                TF_MODEL_FOR_PRETRAINING_MAPPING,
                TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
                TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
                TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
                TF_MODEL_FOR_CAUSAL_LM_MAPPING,
                TF_MODEL_FOR_MASKED_LM_MAPPING,
            )
            for auto_mapping in registered_mappings:
                if NewModelConfig in auto_mapping._extra_content:
                    del auto_mapping._extra_content[NewModelConfig]
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290

    def test_repo_not_found(self):
        """Loading the unknown identifier "bert-base" must raise `EnvironmentError` with the expected message."""
        expected_message = "bert-base is not a local folder and is not a valid model identifier"
        with self.assertRaisesRegex(EnvironmentError, expected_message):
            TFAutoModel.from_pretrained("bert-base")

    def test_revision_not_found(self):
        """Requesting the revision "aaaaaa" must raise `EnvironmentError` with the expected message."""
        # Parentheses are escaped because the expected message is matched as a regular expression.
        expected_message = r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
        with self.assertRaisesRegex(EnvironmentError, expected_message):
            TFAutoModel.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")

    def test_model_file_not_found(self):
        """Loading "hf-internal-testing/config-no-model" must raise `EnvironmentError` naming the missing weights file."""
        # The expected message is matched as a regular expression, so the dot in the file name
        # is escaped to match a literal "." rather than any character.
        with self.assertRaisesRegex(
            EnvironmentError,
            r"hf-internal-testing/config-no-model does not appear to have a file named tf_model\.h5",
        ):
            _ = TFAutoModel.from_pretrained("hf-internal-testing/config-no-model")

    def test_model_from_pt_suggestion(self):
        """Loading "hf-internal-testing/tiny-bert-pt-only" must raise an error that suggests `from_pt=True`."""
        expected_message = "Use `from_pt=True` to load this model"
        with self.assertRaisesRegex(EnvironmentError, expected_message):
            TFAutoModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305

    def test_cached_model_has_minimum_calls_to_head(self):
        """Reloading an already-loaded checkpoint must record one HEAD request and no GET or other requests."""
        # A regular checkpoint first, then a sharded checkpoint.
        for repo_id in ("hf-internal-testing/tiny-random-bert", "ArthurZ/tiny-random-bert-sharded"):
            # Make sure we have cached the model.
            TFAutoModel.from_pretrained(repo_id)
            with RequestCounter() as request_counter:
                TFAutoModel.from_pretrained(repo_id)
                self.assertEqual(request_counter.get_request_count, 0)
                self.assertEqual(request_counter.head_request_count, 1)
                self.assertEqual(request_counter.other_request_count, 0)