test_modeling_tf_auto.py 8.34 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Aymeric Augustin's avatar
Aymeric Augustin committed
15

thomwolf's avatar
thomwolf committed
16

Aymeric Augustin's avatar
Aymeric Augustin committed
17
import unittest
thomwolf's avatar
thomwolf committed
18

19
from transformers import is_tf_available
20
from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_tf, slow
Aymeric Augustin's avatar
Aymeric Augustin committed
21

22

23
# The TF model classes, archive lists and auto-model mappings are imported only
# when `is_tf_available()` returns a truthy value, so that merely importing this
# module does not require those names to be importable.
if is_tf_available():
    from transformers import (
        AutoConfig,
        BertConfig,
        GPT2Config,
        T5Config,
        TFAutoModel,
        TFBertModel,
        TFAutoModelForPreTraining,
        TFBertForPreTraining,
        TFAutoModelWithLMHead,
        TFBertForMaskedLM,
        TFRobertaForMaskedLM,
        TFAutoModelForSequenceClassification,
        TFBertForSequenceClassification,
        TFAutoModelForQuestionAnswering,
        TFBertForQuestionAnswering,
        TFAutoModelForCausalLM,
        TFGPT2LMHeadModel,
        TFAutoModelForMaskedLM,
        TFAutoModelForSeq2SeqLM,
        TFT5ForConditionalGeneration,
    )
    from transformers.modeling_tf_bert import TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST
    from transformers.modeling_tf_gpt2 import TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST
    from transformers.modeling_tf_t5 import TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST
    from transformers.modeling_tf_auto import (
        TF_MODEL_MAPPING,
        TF_MODEL_FOR_PRETRAINING_MAPPING,
        TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
        TF_MODEL_WITH_LM_HEAD_MAPPING,
        TF_MODEL_FOR_CAUSAL_LM_MAPPING,
        TF_MODEL_FOR_MASKED_LM_MAPPING,
        TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    )
thomwolf's avatar
thomwolf committed
60
61


62
@require_tf
class TFAutoModelTest(unittest.TestCase):
    """Checks that the TF auto classes resolve to the expected concrete classes.

    Most tests load a checkpoint through `AutoConfig` and one of the
    `TFAutoModel*` classes and assert the type of the returned objects.
    The last test inspects the ordering of the auto-model mappings.
    """

    @slow
    def test_model_from_pretrained(self):
        """`TFAutoModel` on "bert-base-uncased" yields a `BertConfig` / `TFBertModel`."""
        import h5py

        # The test requires h5py to report an HDF5 library version in the 1.10 series.
        self.assertTrue(h5py.version.hdf5_version.startswith("1.10"))

        for model_name in ["bert-base-uncased"]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

            model = TFAutoModel.from_pretrained(model_name)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, TFBertModel)

    @slow
    def test_model_for_pretraining_from_pretrained(self):
        """`TFAutoModelForPreTraining` on "bert-base-uncased" yields `TFBertForPreTraining`."""
        import h5py

        # The test requires h5py to report an HDF5 library version in the 1.10 series.
        self.assertTrue(h5py.version.hdf5_version.startswith("1.10"))

        for model_name in ["bert-base-uncased"]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

            model = TFAutoModelForPreTraining.from_pretrained(model_name)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, TFBertForPreTraining)

    @slow
    def test_model_for_causal_lm(self):
        """`TFAutoModelForCausalLM` on the first GPT-2 archive entry yields `TFGPT2LMHeadModel`."""
        for model_name in TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, GPT2Config)

            # Plain call form: verify its result before it is replaced below.
            model = TFAutoModelForCausalLM.from_pretrained(model_name)
            self.assertIsInstance(model, TFGPT2LMHeadModel)

            # Call form that additionally returns loading info; only the model is inspected.
            model, _ = TFAutoModelForCausalLM.from_pretrained(model_name, output_loading_info=True)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, TFGPT2LMHeadModel)

    @slow
    def test_lmhead_model_from_pretrained(self):
        """`TFAutoModelWithLMHead` on the first BERT archive entry yields `TFBertForMaskedLM`."""
        for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

            model = TFAutoModelWithLMHead.from_pretrained(model_name)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, TFBertForMaskedLM)

    @slow
    def test_model_for_masked_lm(self):
        """`TFAutoModelForMaskedLM` on the first BERT archive entry yields `TFBertForMaskedLM`."""
        for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

            # Plain call form: verify its result before it is replaced below.
            model = TFAutoModelForMaskedLM.from_pretrained(model_name)
            self.assertIsInstance(model, TFBertForMaskedLM)

            # Call form that additionally returns loading info; only the model is inspected.
            model, _ = TFAutoModelForMaskedLM.from_pretrained(model_name, output_loading_info=True)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, TFBertForMaskedLM)

    @slow
    def test_model_for_encoder_decoder_lm(self):
        """`TFAutoModelForSeq2SeqLM` on the first T5 archive entry yields `TFT5ForConditionalGeneration`."""
        for model_name in TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, T5Config)

            # Plain call form: verify its result before it is replaced below.
            model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
            self.assertIsInstance(model, TFT5ForConditionalGeneration)

            # Call form that additionally returns loading info; only the model is inspected.
            model, _ = TFAutoModelForSeq2SeqLM.from_pretrained(model_name, output_loading_info=True)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, TFT5ForConditionalGeneration)

    @slow
    def test_sequence_classification_model_from_pretrained(self):
        """`TFAutoModelForSequenceClassification` on "bert-base-uncased" yields the BERT classifier."""
        for model_name in ["bert-base-uncased"]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

            model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, TFBertForSequenceClassification)

    @slow
    def test_question_answering_model_from_pretrained(self):
        """`TFAutoModelForQuestionAnswering` on "bert-base-uncased" yields `TFBertForQuestionAnswering`."""
        for model_name in ["bert-base-uncased"]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

            model = TFAutoModelForQuestionAnswering.from_pretrained(model_name)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, TFBertForQuestionAnswering)

    def test_from_pretrained_identifier(self):
        """Loading `SMALL_MODEL_IDENTIFIER` gives a `TFBertForMaskedLM` with 14830 parameters."""
        model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
        self.assertIsInstance(model, TFBertForMaskedLM)
        # Total and trainable-only counts are expected to be equal.
        self.assertEqual(model.num_parameters(), 14830)
        self.assertEqual(model.num_parameters(only_trainable=True), 14830)

    def test_from_identifier_from_model_type(self):
        """Loading `DUMMY_UNKWOWN_IDENTIFIER` gives a `TFRobertaForMaskedLM` with 14830 parameters."""
        model = TFAutoModelWithLMHead.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
        self.assertIsInstance(model, TFRobertaForMaskedLM)
        # Total and trainable-only counts are expected to be equal.
        self.assertEqual(model.num_parameters(), 14830)
        self.assertEqual(model.num_parameters(only_trainable=True), 14830)

    def test_parents_and_children_in_mappings(self):
        """No mapping entry may be a subclass of an entry that precedes it.

        Children must be placed before their parents in the mappings: the
        `isinstance` check used by the auto models would otherwise be triggered
        by the parent entry first and return the wrong type.
        """
        mappings = (
            TF_MODEL_MAPPING,
            TF_MODEL_FOR_PRETRAINING_MAPPING,
            TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
            TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
            TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
            TF_MODEL_WITH_LM_HEAD_MAPPING,
            TF_MODEL_FOR_CAUSAL_LM_MAPPING,
            TF_MODEL_FOR_MASKED_LM_MAPPING,
            TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
        )

        for mapping in mappings:
            # Freeze the (config class, model class) pairs in mapping order so they can be sliced.
            entries = tuple(mapping.items())
            # `index` counts from 0 over entries[1:], so the current entry sits at
            # position index + 1 and entries[: index + 1] is everything before it.
            for index, (child_config, child_model) in enumerate(entries[1:]):
                for parent_config, parent_model in entries[: index + 1]:
                    with self.subTest(
                        msg="Testing if {} is child of {}".format(child_config.__name__, parent_config.__name__)
                    ):
                        self.assertFalse(issubclass(child_config, parent_config))
                        self.assertFalse(issubclass(child_model, parent_model))