# coding=utf-8
# Copyright 2018 Google T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

from transformers import T5Config, is_tf_available
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_tf, slow


# The TF model classes can only be imported when TensorFlow is installed;
# the test class below falls back to empty model tuples otherwise.
if is_tf_available():
    from transformers.modeling_tf_t5 import TFT5Model, TFT5ForConditionalGeneration


@require_tf
class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
    """Tests for the TensorFlow T5 models ``TFT5Model`` and ``TFT5ForConditionalGeneration``.

    Shared model checks come from ``TFModelTesterMixin``. This class adds the
    common config tests, T5-specific output-shape checks and a slow test that
    loads pretrained ``t5-small`` weights.
    """

    # NOTE(review): presumably read by TFModelTesterMixin so it feeds separate
    # encoder/decoder inputs — confirm against the mixin.
    is_encoder_decoder = True

    # The model classes are only imported when TensorFlow is available, so fall
    # back to empty tuples to keep this module importable without it.
    all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
    all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()

    class TFT5ModelTester(object):
        """Build a tiny random ``T5Config`` plus dummy inputs, and run shape checks.

        ``parent`` is the running ``unittest.TestCase``. Its assertion methods
        are used to report failures.
        """

        def __init__(
            self,
            parent,
            batch_size=13,
            seq_length=7,
            is_training=True,
            use_input_mask=True,
            use_labels=True,
            vocab_size=99,
            n_positions=14,
            hidden_size=32,
            num_hidden_layers=5,
            num_attention_heads=4,
            d_ff=37,
            relative_attention_num_buckets=8,
            dropout_rate=0.1,
            initializer_factor=0.002,
            eos_token_id=1,
            pad_token_id=0,
            scope=None,
        ):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.is_training = is_training
            self.use_input_mask = use_input_mask
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.n_positions = n_positions
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.d_ff = d_ff
            self.relative_attention_num_buckets = relative_attention_num_buckets
            self.dropout_rate = dropout_rate
            self.initializer_factor = initializer_factor
            self.eos_token_id = eos_token_id
            self.pad_token_id = pad_token_id
            self.scope = scope

        def _inputs_dict(self, input_ids, input_mask):
            """Return the model-input dict shared by every check in this tester.

            The encoder ids are reused as decoder ids. ``input_mask`` is passed
            only as the decoder attention mask.
            """
            return {
                "input_ids": input_ids,
                "decoder_input_ids": input_ids,
                "decoder_attention_mask": input_mask,
            }

        def prepare_config_and_inputs(self):
            """Create a random config and matching dummy inputs.

            Returns:
                ``(config, input_ids, input_mask, token_labels)``. ``input_ids`` is an
                ``ids_tensor`` of shape ``(batch_size, seq_length)``. ``input_mask`` is
                an ``ids_tensor`` of the same shape drawn with ``vocab_size=2``, or
                ``None`` when ``use_input_mask`` is false. ``token_labels`` is an
                ``ids_tensor`` of the same shape, or ``None`` when ``use_labels`` is
                false.
            """
            input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

            input_mask = None
            if self.use_input_mask:
                input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

            token_labels = None
            if self.use_labels:
                token_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

            config = T5Config(
                vocab_size=self.vocab_size,
                n_positions=self.n_positions,
                d_model=self.hidden_size,
                d_ff=self.d_ff,
                # Split hidden_size evenly across the attention heads.
                d_kv=self.hidden_size // self.num_attention_heads,
                num_layers=self.num_hidden_layers,
                num_heads=self.num_attention_heads,
                relative_attention_num_buckets=self.relative_attention_num_buckets,
                dropout_rate=self.dropout_rate,
                initializer_factor=self.initializer_factor,
                eos_token_id=self.eos_token_id,
                # NOTE(review): BOS deliberately equals PAD here (presumably T5's
                # decoder-start convention) — confirm before changing.
                bos_token_id=self.pad_token_id,
                pad_token_id=self.pad_token_id,
            )

            return (config, input_ids, input_mask, token_labels)

        def create_and_check_t5_model(self, config, input_ids, input_mask, token_labels):
            """Run ``TFT5Model`` with dict inputs and with positional/keyword inputs.

            Both calls must return two outputs, each of shape
            ``(batch_size, seq_length, hidden_size)``.
            """
            model = TFT5Model(config=config)

            # One call with a single dict of inputs...
            dict_outputs = model(self._inputs_dict(input_ids, input_mask))
            # ...and one with the encoder ids positional and the decoder inputs as
            # keywords. The encoder ids must not be passed a second time as a keyword.
            kwarg_outputs = model(input_ids, decoder_input_ids=input_ids, decoder_attention_mask=input_mask)

            expected_shape = [self.batch_size, self.seq_length, self.hidden_size]
            for encoder_output, decoder_output in (dict_outputs, kwarg_outputs):
                self.parent.assertListEqual(list(encoder_output.numpy().shape), expected_shape)
                self.parent.assertListEqual(list(decoder_output.numpy().shape), expected_shape)

        def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
            """Run ``TFT5ForConditionalGeneration`` and check the LM-head score shape.

            The model must return two outputs. The first (prediction scores) must
            have shape ``(batch_size, seq_length, vocab_size)``.
            """
            model = TFT5ForConditionalGeneration(config=config)
            prediction_scores, _ = model(self._inputs_dict(input_ids, input_mask))

            self.parent.assertListEqual(
                list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size]
            )

        def prepare_config_and_inputs_for_common(self):
            """Return ``(config, inputs_dict)`` for the shared ``TFModelTesterMixin`` tests."""
            config, input_ids, input_mask, _ = self.prepare_config_and_inputs()
            return config, self._inputs_dict(input_ids, input_mask)

    def setUp(self):
        self.model_tester = TFT5ModelTest.TFT5ModelTester(self)
        self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_t5_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_t5_model(*config_and_inputs)

    def test_with_lm_head(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in ["t5-small"]:
            model = TFT5Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
            self.assertIsNotNone(model)