# test_modeling_tf_xlnet.py: unit tests for the TensorFlow XLNet models.
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function

import random
import unittest

from transformers import XLNetConfig, is_tf_available

from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_tf, slow


# TensorFlow-dependent imports are guarded so that this module can still be
# imported (and collected) when TensorFlow is not installed.
if is_tf_available():
    import tensorflow as tf

    from transformers.modeling_tf_xlnet import (
        TFXLNetModel,
        TFXLNetLMHeadModel,
        TFXLNetForSequenceClassification,
        TFXLNetForTokenClassification,
        TFXLNetForQuestionAnsweringSimple,
        TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
    )

@require_tf
class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
    """Tests for the TensorFlow XLNet model family.

    The model-specific shape checks are implemented by the nested
    ``TFXLNetModelTester`` helper; the generic model tests come from
    ``TFModelTesterMixin``.
    """

    # The TF model classes are only imported when TensorFlow is available,
    # so fall back to an empty tuple otherwise.
    all_model_classes = (
        (
            TFXLNetModel,
            TFXLNetLMHeadModel,
            TFXLNetForSequenceClassification,
            TFXLNetForTokenClassification,
            TFXLNetForQuestionAnsweringSimple,
        )
        if is_tf_available()
        else ()
    )
    # NOTE(review): presumably read by TFModelTesterMixin to skip its
    # head-pruning tests for XLNet; confirm against the mixin.
    test_pruning = False

    class TFXLNetModelTester(object):
        """Builds a tiny XLNet config plus random inputs and checks model output shapes.

        All assertions are delegated to ``parent`` (the owning ``unittest.TestCase``).
        """

        def __init__(
            self,
            parent,
            batch_size=13,
            seq_length=7,
            mem_len=10,
            clamp_len=-1,
            reuse_len=15,
            is_training=True,
            use_labels=True,
            vocab_size=99,
            cutoffs=None,
            hidden_size=32,
            num_attention_heads=4,
            d_inner=128,
            num_hidden_layers=5,
            type_sequence_label_size=2,
            untie_r=True,
            bi_data=False,
            same_length=False,
            initializer_range=0.05,
            seed=1,
            type_vocab_size=2,
        ):
            """Store the hyper-parameters used to build configs and random inputs.

            Args:
                parent: test case whose ``assert*`` methods are used for checks.
                cutoffs: stored on the tester only (not passed to ``XLNetConfig``
                    here); defaults to ``[10, 50, 80]``.
                seed: value used by :meth:`set_seed`.
                The remaining arguments are copied onto the instance and, where
                applicable, forwarded to ``XLNetConfig`` in
                :meth:`prepare_config_and_inputs`.
            """
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.mem_len = mem_len
            self.clamp_len = clamp_len
            self.reuse_len = reuse_len
            self.is_training = is_training
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            # Fresh list per tester instead of a shared mutable default argument.
            self.cutoffs = [10, 50, 80] if cutoffs is None else cutoffs
            self.hidden_size = hidden_size
            self.num_attention_heads = num_attention_heads
            self.d_inner = d_inner
            self.num_hidden_layers = num_hidden_layers
            self.bi_data = bi_data
            self.untie_r = untie_r
            self.same_length = same_length
            self.initializer_range = initializer_range
            self.seed = seed
            self.type_vocab_size = type_vocab_size
            self.type_sequence_label_size = type_sequence_label_size

        def prepare_config_and_inputs(self):
            """Build a fresh ``XLNetConfig`` and a tuple of random input tensors.

            Returns:
                tuple: ``(config, input_ids_1, input_ids_2, input_ids_q, perm_mask,
                input_mask, target_mapping, segment_ids, lm_labels,
                sequence_labels, is_impossible_labels)``. The three label entries
                are ``None`` when ``use_labels`` is False.
            """
            input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            segment_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
            input_mask = ids_tensor([self.batch_size, self.seq_length], 2, dtype=tf.float32)

            # One extra position: the last token of input_ids_q is the one to predict.
            input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)

            # Previous tokens don't see the last token (i.e. perm_mask[:, :, -1] = 1.0);
            # built with concat because TF tensors do not support item assignment.
            perm_mask = tf.zeros((self.batch_size, self.seq_length + 1, self.seq_length), dtype=tf.float32)
            perm_mask_last = tf.ones((self.batch_size, self.seq_length + 1, 1), dtype=tf.float32)
            perm_mask = tf.concat([perm_mask, perm_mask_last], axis=-1)

            # Single prediction target: the last token (i.e. target_mapping[:, 0, -1] = 1.0).
            target_mapping = tf.zeros((self.batch_size, 1, self.seq_length), dtype=tf.float32)
            target_mapping_last = tf.ones((self.batch_size, 1, 1), dtype=tf.float32)
            target_mapping = tf.concat([target_mapping, target_mapping_last], axis=-1)

            sequence_labels = None
            lm_labels = None
            is_impossible_labels = None
            if self.use_labels:
                lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
                is_impossible_labels = ids_tensor([self.batch_size], 2, dtype=tf.float32)

            config = XLNetConfig(
                vocab_size=self.vocab_size,
                d_model=self.hidden_size,
                n_head=self.num_attention_heads,
                d_inner=self.d_inner,
                n_layer=self.num_hidden_layers,
                untie_r=self.untie_r,
                mem_len=self.mem_len,
                clamp_len=self.clamp_len,
                same_length=self.same_length,
                reuse_len=self.reuse_len,
                bi_data=self.bi_data,
                initializer_range=self.initializer_range,
                num_labels=self.type_sequence_label_size,
            )

            return (
                config,
                input_ids_1,
                input_ids_2,
                input_ids_q,
                perm_mask,
                input_mask,
                target_mapping,
                segment_ids,
                lm_labels,
                sequence_labels,
                is_impossible_labels,
            )

        def set_seed(self):
            """Seed Python's ``random`` module and TensorFlow's global seed with ``self.seed``."""
            random.seed(self.seed)
            tf.random.set_seed(self.seed)

        def _check_mems_shape(self, mems, mem_len):
            """Assert one memory array per layer, each shaped ``(mem_len, batch_size, hidden_size)``."""
            self.parent.assertListEqual(
                [list(mem.shape) for mem in mems],
                [[mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
            )

        def create_and_check_xlnet_base_model(
            self,
            config,
            input_ids_1,
            input_ids_2,
            input_ids_q,
            perm_mask,
            input_mask,
            target_mapping,
            segment_ids,
            lm_labels,
            sequence_labels,
            is_impossible_labels,
        ):
            """Check ``TFXLNetModel`` with dict and list inputs, with and without memory.

            NOTE: sets ``config.mem_len = 0`` on the passed config.
            """
            model = TFXLNetModel(config)

            # Dict-style inputs: the call must unpack into exactly two outputs.
            inputs = {"input_ids": input_ids_1, "input_mask": input_mask, "token_type_ids": segment_ids}
            _, _ = model(inputs)

            # List-style inputs.
            inputs = [input_ids_1, input_mask]
            outputs, mems_1 = model(inputs)

            result = {
                "mems_1": [mem.numpy() for mem in mems_1],
                "outputs": outputs.numpy(),
            }

            # With mem_len = 0 the model is expected to return a single output (no mems).
            config.mem_len = 0
            model = TFXLNetModel(config)
            no_mems_outputs = model(inputs)
            self.parent.assertEqual(len(no_mems_outputs), 1)

            self.parent.assertListEqual(
                list(result["outputs"].shape), [self.batch_size, self.seq_length, self.hidden_size]
            )
            self._check_mems_shape(result["mems_1"], self.seq_length)

        def create_and_check_xlnet_lm_head(
            self,
            config,
            input_ids_1,
            input_ids_2,
            input_ids_q,
            perm_mask,
            input_mask,
            target_mapping,
            segment_ids,
            lm_labels,
            sequence_labels,
            is_impossible_labels,
        ):
            """Check ``TFXLNetLMHeadModel`` logits and mems shapes, including feeding mems back in."""
            model = TFXLNetLMHeadModel(config)

            inputs_1 = {"input_ids": input_ids_1, "token_type_ids": segment_ids}
            all_logits_1, mems_1 = model(inputs_1)

            # Second segment re-uses the memory produced by the first call.
            inputs_2 = {"input_ids": input_ids_2, "mems": mems_1, "token_type_ids": segment_ids}
            all_logits_2, mems_2 = model(inputs_2)

            # Smoke test of a call with perm_mask and target_mapping; its outputs are not checked here.
            inputs_3 = {"input_ids": input_ids_q, "perm_mask": perm_mask, "target_mapping": target_mapping}
            logits, _ = model(inputs_3)

            result = {
                "mems_1": [mem.numpy() for mem in mems_1],
                "all_logits_1": all_logits_1.numpy(),
                "mems_2": [mem.numpy() for mem in mems_2],
                "all_logits_2": all_logits_2.numpy(),
            }

            self.parent.assertListEqual(
                list(result["all_logits_1"].shape), [self.batch_size, self.seq_length, self.vocab_size]
            )
            self._check_mems_shape(result["mems_1"], self.seq_length)

            self.parent.assertListEqual(
                list(result["all_logits_2"].shape), [self.batch_size, self.seq_length, self.vocab_size]
            )
            # After accumulating two segments the memory is expected to have length mem_len.
            self._check_mems_shape(result["mems_2"], self.mem_len)

        def create_and_check_xlnet_qa(
            self,
            config,
            input_ids_1,
            input_ids_2,
            input_ids_q,
            perm_mask,
            input_mask,
            target_mapping,
            segment_ids,
            lm_labels,
            sequence_labels,
            is_impossible_labels,
        ):
            """Check ``TFXLNetForQuestionAnsweringSimple`` start/end logits and mems shapes."""
            model = TFXLNetForQuestionAnsweringSimple(config)

            inputs = {"input_ids": input_ids_1, "attention_mask": input_mask, "token_type_ids": segment_ids}
            start_logits, end_logits, mems = model(inputs)

            result = {
                "start_logits": start_logits.numpy(),
                "end_logits": end_logits.numpy(),
                "mems": [m.numpy() for m in mems],
            }

            self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
            self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
            self._check_mems_shape(result["mems"], self.seq_length)

        def create_and_check_xlnet_sequence_classif(
            self,
            config,
            input_ids_1,
            input_ids_2,
            input_ids_q,
            perm_mask,
            input_mask,
            target_mapping,
            segment_ids,
            lm_labels,
            sequence_labels,
            is_impossible_labels,
        ):
            """Check ``TFXLNetForSequenceClassification`` logits and mems shapes (bare tensor input)."""
            model = TFXLNetForSequenceClassification(config)

            logits, mems_1 = model(input_ids_1)

            result = {
                "mems_1": [mem.numpy() for mem in mems_1],
                "logits": logits.numpy(),
            }

            self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.type_sequence_label_size])
            self._check_mems_shape(result["mems_1"], self.seq_length)

        def create_and_check_xlnet_for_token_classification(
            self,
            config,
            input_ids_1,
            input_ids_2,
            input_ids_q,
            perm_mask,
            input_mask,
            target_mapping,
            segment_ids,
            lm_labels,
            sequence_labels,
            is_impossible_labels,
        ):
            """Check ``TFXLNetForTokenClassification`` logits and mems shapes.

            NOTE: sets ``config.num_labels`` to the sequence length, which is just an
            arbitrary label count used for the shape check.
            """
            config.num_labels = input_ids_1.shape[1]
            model = TFXLNetForTokenClassification(config)

            inputs = {
                "input_ids": input_ids_1,
                "attention_mask": input_mask,
            }
            logits, mems_1 = model(inputs)

            result = {
                "mems_1": [mem.numpy() for mem in mems_1],
                "logits": logits.numpy(),
            }

            self.parent.assertListEqual(
                list(result["logits"].shape), [self.batch_size, self.seq_length, config.num_labels]
            )
            self._check_mems_shape(result["mems_1"], self.seq_length)

        def prepare_config_and_inputs_for_common(self):
            """Return ``(config, inputs_dict)`` with only ``input_ids`` for the shared mixin tests."""
            (
                config,
                input_ids_1,
                input_ids_2,
                input_ids_q,
                perm_mask,
                input_mask,
                target_mapping,
                segment_ids,
                lm_labels,
                sequence_labels,
                is_impossible_labels,
            ) = self.prepare_config_and_inputs()
            inputs_dict = {"input_ids": input_ids_1}
            return config, inputs_dict

    def setUp(self):
        self.model_tester = TFXLNetModelTest.TFXLNetModelTester(self)
        self.config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_xlnet_base_model(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs)

    def test_xlnet_lm_head(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_lm_head(*config_and_inputs)

    def test_xlnet_sequence_classif(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_sequence_classif(*config_and_inputs)

    def test_xlnet_token_classification(self):
        # Seeded like the other model tests so the random inputs are reproducible.
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_for_token_classification(*config_and_inputs)

    def test_xlnet_qa(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_qa(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        # Only the first archive entry is downloaded, to keep the slow test bounded.
        for model_name in list(TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = TFXLNetModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
            self.assertIsNotNone(model)