# coding=utf-8
# Copyright 2018 Google T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
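"""Tests for the TensorFlow T5 models (TFT5Model and TFT5WithLMHeadModel)."""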


import unittest

from transformers import T5Config, is_tf_available

from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_tf, slow


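# The TF T5 classes are only importable when TensorFlow is installed.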
if is_tf_available():
    from transformers.modeling_tf_t5 import TFT5Model, TFT5WithLMHeadModel


@require_tf
class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):

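    # Class-level attributes consumed by the shared TFModelTesterMixin tests.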
    is_encoder_decoder = True
    all_model_classes = (TFT5Model, TFT5WithLMHeadModel) if is_tf_available() else ()

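    # Helper that builds a small random configuration and matching input tensors
    # for the model checks below.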
    class TFT5ModelTester(object):
        def __init__(
            self,
            parent,
            batch_size=13,
            seq_length=7,
            is_training=True,
            use_input_mask=True,
            use_labels=True,
            vocab_size=99,
            n_positions=14,
            hidden_size=32,
            num_hidden_layers=5,
            num_attention_heads=4,
            d_ff=37,
            relative_attention_num_buckets=8,
            dropout_rate=0.1,
            initializer_factor=0.002,
            scope=None,
        ):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.is_training = is_training
            self.use_input_mask = use_input_mask
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.n_positions = n_positions
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.d_ff = d_ff
            self.relative_attention_num_buckets = relative_attention_num_buckets
            self.dropout_rate = dropout_rate
            self.initializer_factor = initializer_factor
            self.scope = scope

        def prepare_config_and_inputs(self):
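            # Random token ids in [0, vocab_size) serve as both encoder and decoder inputs.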
            input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

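            # Sampling ids with vocab_size=2 yields a binary (0/1) attention mask.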
            input_mask = None
            if self.use_input_mask:
                input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

            token_labels = None
            if self.use_labels:
                token_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

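            # A deliberately tiny configuration keeps the tests fast.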
            config = T5Config(
                vocab_size=self.vocab_size,
                n_positions=self.n_positions,
                d_model=self.hidden_size,
                d_ff=self.d_ff,
                d_kv=self.hidden_size // self.num_attention_heads,
                num_layers=self.num_hidden_layers,
                num_heads=self.num_attention_heads,
                relative_attention_num_buckets=self.relative_attention_num_buckets,
                dropout_rate=self.dropout_rate,
                initializer_factor=self.initializer_factor,
            )

            return (config, input_ids, input_mask, token_labels)

        def create_and_check_t5_model(self, config, input_ids, input_mask, token_labels):
            model = TFT5Model(config=config)
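            # First call the model with a single dictionary of named inputs...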
            inputs = {
                "encoder_input_ids": input_ids,
                "decoder_input_ids": input_ids,
                "decoder_attention_mask": input_mask,
            }
            encoder_output, decoder_output = model(inputs)

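            # ...then repeat the same forward pass with keyword arguments;
            # both calling conventions should produce outputs of the same shape.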
            encoder_output, decoder_output = model(
                input_ids, decoder_attention_mask=input_mask, encoder_input_ids=input_ids
            )

            result = {
                "encoder_output": encoder_output.numpy(),
                "decoder_output": decoder_output.numpy(),
            }
            self.parent.assertListEqual(
                list(result["encoder_output"].shape), [self.batch_size, self.seq_length, self.hidden_size]
            )
            self.parent.assertListEqual(
                list(result["decoder_output"].shape), [self.batch_size, self.seq_length, self.hidden_size]
            )

        def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
            model = TFT5WithLMHeadModel(config=config)
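            # The LM head model should return vocabulary logits for every decoder position.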
            inputs = {
                "encoder_input_ids": input_ids,
                "decoder_input_ids": input_ids,
                "decoder_attention_mask": input_mask,
            }
            prediction_scores, decoder_output = model(inputs)
            result = {
                "prediction_scores": prediction_scores.numpy(),
            }
            self.parent.assertListEqual(
                list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size]
            )

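        # Repackage the inputs in the dictionary format expected by the common mixin tests.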
        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
            (config, input_ids, input_mask, token_labels) = config_and_inputs
            inputs_dict = {
                "encoder_input_ids": input_ids,
                "decoder_input_ids": input_ids,
                "decoder_attention_mask": input_mask,
            }
            return config, inputs_dict

    def setUp(self):
        self.model_tester = TFT5ModelTest.TFT5ModelTester(self)
        self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_t5_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_t5_model(*config_and_inputs)

    def test_with_lm_head(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs)

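    # Marked @slow because it downloads the pretrained t5-small weights.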
    @slow
    def test_model_from_pretrained(self):
        for model_name in ["t5-small"]:
            model = TFT5Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
            self.assertIsNotNone(model)