# coding=utf-8
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import copy
import tempfile
import unittest

from transformers import CONFIG_MAPPING, AutoConfig, BertConfig, GPT2Config, T5Config, TapasConfig, is_tf_available
from transformers.testing_utils import (
    DUMMY_UNKNOWN_IDENTIFIER,
    SMALL_MODEL_IDENTIFIER,
    RequestCounter,
    require_tensorflow_probability,
    require_tf,
    slow,
)

from ..bert.test_modeling_bert import BertModelTester


# The TF model classes and auto mappings can only be imported when TensorFlow is installed;
# the test class below is additionally guarded by `@require_tf`.
if is_tf_available():
    from transformers import (
        TFAutoModel,
        TFAutoModelForCausalLM,
        TFAutoModelForMaskedLM,
        TFAutoModelForPreTraining,
        TFAutoModelForQuestionAnswering,
        TFAutoModelForSeq2SeqLM,
        TFAutoModelForSequenceClassification,
        TFAutoModelForTableQuestionAnswering,
        TFAutoModelForTokenClassification,
        TFAutoModelWithLMHead,
        TFBertForMaskedLM,
        TFBertForPreTraining,
        TFBertForQuestionAnswering,
        TFBertForSequenceClassification,
        TFBertModel,
        TFFunnelBaseModel,
        TFFunnelModel,
        TFGPT2LMHeadModel,
        TFRobertaForMaskedLM,
        TFT5ForConditionalGeneration,
        TFTapasForQuestionAnswering,
    )
    from transformers.models.auto.modeling_tf_auto import (
        TF_MODEL_FOR_CAUSAL_LM_MAPPING,
        TF_MODEL_FOR_MASKED_LM_MAPPING,
        TF_MODEL_FOR_PRETRAINING_MAPPING,
        TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
        TF_MODEL_MAPPING,
    )


class NewModelConfig(BertConfig):
    """A ``BertConfig`` subclass advertised under the custom ``model_type`` "new-model".

    It lets ``test_new_model_registration`` register a config/model pair with ``AutoConfig`` and the TF auto
    classes.
    """

    model_type = "new-model"


if is_tf_available():

    class TFNewModel(TFBertModel):
        """``TFBertModel`` bound to ``NewModelConfig`` so the TF auto classes can map that config to it."""

        config_class = NewModelConfig


@require_tf
class TFAutoModelTest(unittest.TestCase):
    """Tests for the TensorFlow auto classes (``TFAutoModel*``) together with ``AutoConfig``.

    The ``@slow`` tests load real Hub checkpoints and check that each auto class maps the checkpoint's config to
    the expected concrete TF model class. Each ``model_name`` must therefore match the config/model classes it is
    asserted against. The remaining tests cover:

    - tiny testing repositories and tuple-valued auto mappings;
    - registration of a custom config/model pair;
    - error messages for missing repos, revisions and weight files;
    - HTTP request counts when loading from the cache.
    """

    @slow
    def test_model_from_pretrained(self):
        model_name = "google-bert/bert-base-cased"
        config = AutoConfig.from_pretrained(model_name)
        self.assertIsNotNone(config)
        self.assertIsInstance(config, BertConfig)

        model = TFAutoModel.from_pretrained(model_name)
        self.assertIsNotNone(model)
        self.assertIsInstance(model, TFBertModel)

    @slow
    def test_model_for_pretraining_from_pretrained(self):
        model_name = "google-bert/bert-base-cased"
        config = AutoConfig.from_pretrained(model_name)
        self.assertIsNotNone(config)
        self.assertIsInstance(config, BertConfig)

        model = TFAutoModelForPreTraining.from_pretrained(model_name)
        self.assertIsNotNone(model)
        self.assertIsInstance(model, TFBertForPreTraining)

    @slow
    def test_model_for_causal_lm(self):
        model_name = "openai-community/gpt2"
        config = AutoConfig.from_pretrained(model_name)
        self.assertIsNotNone(config)
        self.assertIsInstance(config, GPT2Config)

        model = TFAutoModelForCausalLM.from_pretrained(model_name)
        self.assertIsInstance(model, TFGPT2LMHeadModel)

        # With `output_loading_info=True` the call returns a `(model, loading_info)` pair instead of the model.
        model, loading_info = TFAutoModelForCausalLM.from_pretrained(model_name, output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertIsInstance(model, TFGPT2LMHeadModel)
        self.assertIsNotNone(loading_info)

    @slow
    def test_lmhead_model_from_pretrained(self):
        # Must be a BERT checkpoint: the assertions below expect BertConfig / TFBertForMaskedLM.
        model_name = "google-bert/bert-base-cased"
        config = AutoConfig.from_pretrained(model_name)
        self.assertIsNotNone(config)
        self.assertIsInstance(config, BertConfig)

        model = TFAutoModelWithLMHead.from_pretrained(model_name)
        self.assertIsNotNone(model)
        self.assertIsInstance(model, TFBertForMaskedLM)

    @slow
    def test_model_for_masked_lm(self):
        # Must be a BERT checkpoint: the assertions below expect BertConfig / TFBertForMaskedLM.
        model_name = "google-bert/bert-base-cased"
        config = AutoConfig.from_pretrained(model_name)
        self.assertIsNotNone(config)
        self.assertIsInstance(config, BertConfig)

        model = TFAutoModelForMaskedLM.from_pretrained(model_name)
        self.assertIsInstance(model, TFBertForMaskedLM)

        model, loading_info = TFAutoModelForMaskedLM.from_pretrained(model_name, output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertIsInstance(model, TFBertForMaskedLM)
        self.assertIsNotNone(loading_info)

    @slow
    def test_model_for_encoder_decoder_lm(self):
        # Must be a T5 checkpoint: the assertions below expect T5Config / TFT5ForConditionalGeneration.
        model_name = "google-t5/t5-base"
        config = AutoConfig.from_pretrained(model_name)
        self.assertIsNotNone(config)
        self.assertIsInstance(config, T5Config)

        model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.assertIsInstance(model, TFT5ForConditionalGeneration)

        model, loading_info = TFAutoModelForSeq2SeqLM.from_pretrained(model_name, output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertIsInstance(model, TFT5ForConditionalGeneration)
        self.assertIsNotNone(loading_info)

    @slow
    def test_sequence_classification_model_from_pretrained(self):
        model_name = "google-bert/bert-base-uncased"
        config = AutoConfig.from_pretrained(model_name)
        self.assertIsNotNone(config)
        self.assertIsInstance(config, BertConfig)

        model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
        self.assertIsNotNone(model)
        self.assertIsInstance(model, TFBertForSequenceClassification)

    @slow
    def test_question_answering_model_from_pretrained(self):
        model_name = "google-bert/bert-base-uncased"
        config = AutoConfig.from_pretrained(model_name)
        self.assertIsNotNone(config)
        self.assertIsInstance(config, BertConfig)

        model = TFAutoModelForQuestionAnswering.from_pretrained(model_name)
        self.assertIsNotNone(model)
        self.assertIsInstance(model, TFBertForQuestionAnswering)

    @slow
    @require_tensorflow_probability
    def test_table_question_answering_model_from_pretrained(self):
        model_name = "google/tapas-base"
        config = AutoConfig.from_pretrained(model_name)
        self.assertIsNotNone(config)
        self.assertIsInstance(config, TapasConfig)

        model = TFAutoModelForTableQuestionAnswering.from_pretrained(model_name)
        self.assertIsInstance(model, TFTapasForQuestionAnswering)

        model, loading_info = TFAutoModelForTableQuestionAnswering.from_pretrained(
            model_name, output_loading_info=True
        )
        self.assertIsNotNone(model)
        self.assertIsInstance(model, TFTapasForQuestionAnswering)
        self.assertIsNotNone(loading_info)

    def test_from_pretrained_identifier(self):
        model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
        self.assertIsInstance(model, TFBertForMaskedLM)
        # Every parameter of the tiny checkpoint is expected to be trainable.
        self.assertEqual(model.num_parameters(), 14410)
        self.assertEqual(model.num_parameters(only_trainable=True), 14410)

    def test_from_identifier_from_model_type(self):
        # The concrete class is expected to come from the repo config's model type (Roberta here).
        model = TFAutoModelWithLMHead.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER)
        self.assertIsInstance(model, TFRobertaForMaskedLM)
        self.assertEqual(model.num_parameters(), 14410)
        self.assertEqual(model.num_parameters(only_trainable=True), 14410)

    def test_from_pretrained_with_tuple_values(self):
        # For the auto model mapping, FunnelConfig has two models: FunnelModel and FunnelBaseModel
        model = TFAutoModel.from_pretrained("sgugger/funnel-random-tiny")
        self.assertIsInstance(model, TFFunnelModel)

        # `config.architectures` selects between the two candidates of the tuple-valued mapping entry.
        config = copy.deepcopy(model.config)
        config.architectures = ["FunnelBaseModel"]
        model = TFAutoModel.from_config(config)
        model.build_in_name_scope()

        self.assertIsInstance(model, TFFunnelBaseModel)

        # The selected architecture must survive a save/load round trip.
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            model = TFAutoModel.from_pretrained(tmp_dir)
            self.assertIsInstance(model, TFFunnelBaseModel)

    def test_new_model_registration(self):
        try:
            AutoConfig.register("new-model", NewModelConfig)

            auto_classes = [
                TFAutoModel,
                TFAutoModelForCausalLM,
                TFAutoModelForMaskedLM,
                TFAutoModelForPreTraining,
                TFAutoModelForQuestionAnswering,
                TFAutoModelForSequenceClassification,
                TFAutoModelForTokenClassification,
            ]

            for auto_class in auto_classes:
                with self.subTest(auto_class.__name__):
                    # Wrong config class will raise an error
                    with self.assertRaises(ValueError):
                        auto_class.register(BertConfig, TFNewModel)
                    auto_class.register(NewModelConfig, TFNewModel)
                    # Trying to register something existing in the Transformers library will raise an error
                    with self.assertRaises(ValueError):
                        auto_class.register(BertConfig, TFBertModel)

                    # Now that the config is registered, it can be used as any other config with the auto-API
                    tiny_config = BertModelTester(self).get_config()
                    config = NewModelConfig(**tiny_config.to_dict())

                    model = auto_class.from_config(config)
                    model.build_in_name_scope()

                    self.assertIsInstance(model, TFNewModel)

                    with tempfile.TemporaryDirectory() as tmp_dir:
                        model.save_pretrained(tmp_dir)
                        new_model = auto_class.from_pretrained(tmp_dir)
                        self.assertIsInstance(new_model, TFNewModel)

        finally:
            # Undo every registration above so the custom classes do not leak into other tests.
            if "new-model" in CONFIG_MAPPING._extra_content:
                del CONFIG_MAPPING._extra_content["new-model"]
            for mapping in (
                TF_MODEL_MAPPING,
                TF_MODEL_FOR_PRETRAINING_MAPPING,
                TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
                TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
                TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
                TF_MODEL_FOR_CAUSAL_LM_MAPPING,
                TF_MODEL_FOR_MASKED_LM_MAPPING,
            ):
                if NewModelConfig in mapping._extra_content:
                    del mapping._extra_content[NewModelConfig]

    def test_repo_not_found(self):
        with self.assertRaisesRegex(
            EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
        ):
            _ = TFAutoModel.from_pretrained("bert-base")

    def test_revision_not_found(self):
        with self.assertRaisesRegex(
            EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
        ):
            _ = TFAutoModel.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")

    def test_model_file_not_found(self):
        with self.assertRaisesRegex(
            EnvironmentError,
            "hf-internal-testing/config-no-model does not appear to have a file named pytorch_model.bin",
        ):
            _ = TFAutoModel.from_pretrained("hf-internal-testing/config-no-model")

    def test_model_from_pt_suggestion(self):
        with self.assertRaisesRegex(EnvironmentError, "Use `from_pt=True` to load this model"):
            _ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")

    def test_cached_model_has_minimum_calls_to_head(self):
        # Make sure we have cached the model.
        _ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
        # Reloading from the cache must issue exactly one HEAD request and no GET downloads.
        with RequestCounter() as counter:
            _ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
        self.assertEqual(counter["GET"], 0)
        self.assertEqual(counter["HEAD"], 1)
        self.assertEqual(counter.total_calls, 1)

        # With a sharded checkpoint
        _ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
        with RequestCounter() as counter:
            _ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
        self.assertEqual(counter["GET"], 0)
        self.assertEqual(counter["HEAD"], 1)
        self.assertEqual(counter.total_calls, 1)