# test_modeling_tf_auto.py
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

from transformers import is_tf_available
from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_tf, slow


if is_tf_available():
    # TF model classes are imported only when ``is_tf_available()`` is true; the test class
    # below is additionally decorated with ``require_tf``. Names are grouped as
    # (Auto factory, expected concrete class) pairs to mirror the assertions in the tests.
    from transformers import (
        AutoConfig,
        BertConfig,
        TFAutoModel,
        TFBertModel,
        TFAutoModelForPreTraining,
        TFBertForPreTraining,
        TFAutoModelWithLMHead,
        TFBertForMaskedLM,
        TFRobertaForMaskedLM,
        TFAutoModelForSequenceClassification,
        TFBertForSequenceClassification,
        TFAutoModelForQuestionAnswering,
        TFBertForQuestionAnswering,
    )


@require_tf
class TFAutoModelTest(unittest.TestCase):
    """Tests that the TensorFlow ``TFAuto*`` factories resolve checkpoints to the expected classes.

    The ``@slow`` tests load the ``bert-base-uncased`` checkpoint through each Auto factory; the
    remaining tests use the tiny checkpoints named by ``SMALL_MODEL_IDENTIFIER`` and
    ``DUMMY_UNKWOWN_IDENTIFIER`` from ``transformers.testing_utils``.
    """

    # Parameter count (total and trainable-only) expected for both tiny test checkpoints.
    _TINY_MODEL_NUM_PARAMETERS = 14830

    def _check_auto_model(self, auto_model_class, expected_model_class, model_name="bert-base-uncased"):
        """Load ``model_name`` via ``AutoConfig`` and ``auto_model_class`` and check the resolved types.

        Args:
            auto_model_class: The ``TFAuto*`` factory whose ``from_pretrained`` is exercised.
            expected_model_class: The concrete TF model class the factory must return.
            model_name: Checkpoint identifier to load.

        Asserts that the loaded config is a ``BertConfig`` and that the loaded model is an instance
        of ``expected_model_class``. Not prefixed with ``test`` so unittest does not collect it.
        """
        # BertConfig is referenced here rather than as a default argument: it only exists when
        # TensorFlow is available, but this class body is evaluated at import time regardless.
        config = AutoConfig.from_pretrained(model_name)
        self.assertIsNotNone(config)
        self.assertIsInstance(config, BertConfig)

        model = auto_model_class.from_pretrained(model_name)
        self.assertIsNotNone(model)
        self.assertIsInstance(model, expected_model_class)

    def _check_tiny_lm_head_model(self, identifier, expected_model_class):
        """Load ``identifier`` with ``TFAutoModelWithLMHead`` and check its class and parameter counts."""
        model = TFAutoModelWithLMHead.from_pretrained(identifier)
        self.assertIsInstance(model, expected_model_class)
        self.assertEqual(model.num_parameters(), self._TINY_MODEL_NUM_PARAMETERS)
        self.assertEqual(model.num_parameters(only_trainable=True), self._TINY_MODEL_NUM_PARAMETERS)

    @slow
    def test_model_from_pretrained(self):
        self._check_auto_model(TFAutoModel, TFBertModel)

    @slow
    def test_model_for_pretraining_from_pretrained(self):
        self._check_auto_model(TFAutoModelForPreTraining, TFBertForPreTraining)

    @slow
    def test_lmhead_model_from_pretrained(self):
        self._check_auto_model(TFAutoModelWithLMHead, TFBertForMaskedLM)

    @slow
    def test_sequence_classification_model_from_pretrained(self):
        self._check_auto_model(TFAutoModelForSequenceClassification, TFBertForSequenceClassification)

    @slow
    def test_question_answering_model_from_pretrained(self):
        self._check_auto_model(TFAutoModelForQuestionAnswering, TFBertForQuestionAnswering)

    def test_from_pretrained_identifier(self):
        self._check_tiny_lm_head_model(SMALL_MODEL_IDENTIFIER, TFBertForMaskedLM)

    def test_from_identifier_from_model_type(self):
        # The identifier's name carries no architecture hint, so the Auto factory presumably has
        # to rely on the checkpoint's config to pick the RoBERTa class — confirm in testing_utils.
        self._check_tiny_lm_head_model(DUMMY_UNKWOWN_IDENTIFIER, TFRobertaForMaskedLM)