# coding=utf-8
# Copyright 2018 Google T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

from transformers import is_torch_available

from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_torch, slow, torch_device


if is_torch_available():
    import torch
    from transformers import T5Config, T5Model, T5ForConditionalGeneration
    from transformers.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP


@require_torch
class T5ModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
    all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
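    # Flags below opt T5 out of the common pruning, TorchScript, and embedding-resizing tests.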
    test_pruning = False
    test_torchscript = False
    test_resize_embeddings = False
    is_encoder_decoder = True

    class T5ModelTester(object):
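        # Builds a deliberately tiny T5 configuration and random inputs so the tests run fast on CPU.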
        def __init__(
            self,
            parent,
            batch_size=13,
            encoder_seq_length=7,
            decoder_seq_length=9,
            is_training=True,
            use_attention_mask=True,
            use_labels=True,
            vocab_size=99,
            n_positions=14,
            hidden_size=32,
            num_hidden_layers=5,
            num_attention_heads=4,
            d_ff=37,
            relative_attention_num_buckets=8,
            dropout_rate=0.1,
            initializer_factor=0.002,
            eos_token_id=1,
            pad_token_id=0,
            decoder_start_token_id=0,
            scope=None,
        ):
            self.parent = parent
            self.batch_size = batch_size
            self.encoder_seq_length = encoder_seq_length
            self.decoder_seq_length = decoder_seq_length
            self.is_training = is_training
            self.use_attention_mask = use_attention_mask
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.n_positions = n_positions
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.d_ff = d_ff
            self.relative_attention_num_buckets = relative_attention_num_buckets
            self.dropout_rate = dropout_rate
            self.initializer_factor = initializer_factor
            self.scope = scope
            self.eos_token_id = eos_token_id
            self.pad_token_id = pad_token_id
            self.decoder_start_token_id = decoder_start_token_id

        def prepare_config_and_inputs(self):
            input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
            decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)

            attention_mask = None
            decoder_attention_mask = None
            if self.use_attention_mask:
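                # ids_tensor with vocab_size=2 yields random 0/1 tensors, i.e. random attention masks.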
                attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
                decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)

            lm_labels = None
            if self.use_labels:
                lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)

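            # d_kv is derived from the tester parameters so that num_heads * d_kv == d_model.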
            config = T5Config(
                vocab_size=self.vocab_size,
                n_positions=self.n_positions,
                d_model=self.hidden_size,
                d_ff=self.d_ff,
                d_kv=self.hidden_size // self.num_attention_heads,
                num_layers=self.num_hidden_layers,
                num_heads=self.num_attention_heads,
                relative_attention_num_buckets=self.relative_attention_num_buckets,
                dropout_rate=self.dropout_rate,
                initializer_factor=self.initializer_factor,
                eos_token_id=self.eos_token_id,
                bos_token_id=self.pad_token_id,
                pad_token_id=self.pad_token_id,
                decoder_start_token_id=self.decoder_start_token_id,
            )

            return (
                config,
                input_ids,
                decoder_input_ids,
                attention_mask,
                decoder_attention_mask,
                lm_labels,
            )

        def check_loss_output(self, result):
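            # the loss must reduce to a scalar, i.e. a 0-dim tensor with an empty size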
            self.parent.assertListEqual(list(result["loss"].size()), [])

        def check_prepare_lm_labels_via_shift_left(
            self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
        ):
            model = T5Model(config=config)
            model.to(torch_device)
            model.eval()

            # make sure lm_labels are only padded from the right: replace stray decoder_start (= pad) tokens with EOS
            lm_labels.masked_fill_((lm_labels == self.decoder_start_token_id), self.eos_token_id)

            # add causal pad token mask
            triangular_mask = torch.tril(lm_labels.new_ones(lm_labels.shape)).logical_not()
            lm_labels.masked_fill_(triangular_mask, self.pad_token_id)
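            # _shift_right prepends decoder_start_token_id and drops the last label token (teacher forcing)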
            decoder_input_ids = model._shift_right(lm_labels)

            for i, (decoder_input_ids_slice, lm_labels_slice) in enumerate(zip(decoder_input_ids, lm_labels)):
                # first item
                self.parent.assertEqual(decoder_input_ids_slice[0].item(), self.decoder_start_token_id)
                if i < decoder_input_ids_slice.shape[-1]:
                    if i < decoder_input_ids.shape[-1] - 1:
                        # items before diagonal
                        self.parent.assertListEqual(
                            decoder_input_ids_slice[1 : i + 1].tolist(), lm_labels_slice[:i].tolist()
                        )
                    # pad items after diagonal
                    if i < decoder_input_ids.shape[-1] - 2:
                        self.parent.assertListEqual(
                            decoder_input_ids_slice[i + 2 :].tolist(), lm_labels_slice[i + 1 : -1].tolist()
                        )
                else:
                    # all items after square
                    self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist())

        def create_and_check_t5_model(
            self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
        ):
            model = T5Model(config=config)
            model.to(torch_device)
            model.eval()
            decoder_output, encoder_output = model(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                attention_mask=attention_mask,
                decoder_attention_mask=decoder_attention_mask,
            )
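            # a second forward pass without masks checks that the attention masks are optional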
            decoder_output, encoder_output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

            result = {
                "encoder_output": encoder_output,
                "decoder_output": decoder_output,
            }
            self.parent.assertListEqual(
                list(result["encoder_output"].size()), [self.batch_size, self.encoder_seq_length, self.hidden_size]
            )
            self.parent.assertListEqual(
                list(result["decoder_output"].size()), [self.batch_size, self.decoder_seq_length, self.hidden_size]
            )

        def create_and_check_t5_with_lm_head(
            self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
        ):
            model = T5ForConditionalGeneration(config=config)
            model.to(torch_device)
            model.eval()
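            # passing lm_labels makes the model compute the LM loss and return it as the first output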
            outputs = model(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                lm_labels=lm_labels,
            )
            loss, prediction_scores, encoder_features = outputs
            self.parent.assertEqual(len(outputs), 3)
            result = {
                "loss": loss,
                "prediction_scores": prediction_scores,
            }
            self.parent.assertListEqual(
                list(result["prediction_scores"].size()), [self.batch_size, self.decoder_seq_length, self.vocab_size]
            )
            self.check_loss_output(result)

        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
            (
                config,
                input_ids,
                decoder_input_ids,
                attention_mask,
                decoder_attention_mask,
                lm_labels,
            ) = config_and_inputs

            inputs_dict = {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_input_ids,
                "decoder_attention_mask": decoder_attention_mask,
            }
            return config, inputs_dict

    def setUp(self):
        self.model_tester = T5ModelTest.T5ModelTester(self)
        self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_shift_right(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.check_prepare_lm_labels_via_shift_left(*config_and_inputs)

    def test_t5_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_t5_model(*config_and_inputs)

    def test_with_lm_head(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
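        # only the first checkpoint in the archive map is downloaded, to keep this slow test short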
        for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = T5Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
            self.assertIsNotNone(model)