test_modeling_tf_xlnet.py 14.4 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Aymeric Augustin's avatar
Aymeric Augustin committed
15
from __future__ import absolute_import, division, print_function
thomwolf's avatar
thomwolf committed
16
17
18

import random

19
from transformers import XLNetConfig, is_tf_available
thomwolf's avatar
thomwolf committed
20

21
22
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor
Aymeric Augustin's avatar
Aymeric Augustin committed
23
24
25
from .utils import CACHE_DIR, require_tf, slow


thomwolf's avatar
thomwolf committed
26
27
28
if is_tf_available():
    import tensorflow as tf

29
30
31
32
33
34
35
36
37
    from transformers.modeling_tf_xlnet import (
        TFXLNetModel,
        TFXLNetLMHeadModel,
        TFXLNetForSequenceClassification,
        TFXLNetForTokenClassification,
        TFXLNetForQuestionAnsweringSimple,
        TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
    )

38
39

@require_tf
thomwolf's avatar
thomwolf committed
40
41
class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):

42
43
44
45
46
47
48
49
50
51
52
    # Model classes handed to the shared test-suite base class. The tuple is
    # empty when TensorFlow is unavailable, because the TFXLNet* names are only
    # imported under the same ``is_tf_available()`` guard at the top of the file.
    all_model_classes = (
        (
            TFXLNetModel,
            TFXLNetLMHeadModel,
            TFXLNetForSequenceClassification,
            TFXLNetForTokenClassification,
            TFXLNetForQuestionAnsweringSimple,
        )
        if is_tf_available()
        else ()
    )
    # NOTE(review): presumably read by the shared TFCommonModelTester suite to
    # skip its pruning tests -- confirm against test_modeling_tf_common.
    test_pruning = False

    class TFXLNetModelTester(object):
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
        def __init__(
            self,
            parent,
            batch_size=13,
            seq_length=7,
            mem_len=10,
            clamp_len=-1,
            reuse_len=15,
            is_training=True,
            use_labels=True,
            vocab_size=99,
            cutoffs=[10, 50, 80],
            hidden_size=32,
            num_attention_heads=4,
            d_inner=128,
            num_hidden_layers=5,
            type_sequence_label_size=2,
            untie_r=True,
            bi_data=False,
            same_length=False,
            initializer_range=0.05,
            seed=1,
            type_vocab_size=2,
        ):
thomwolf's avatar
thomwolf committed
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.mem_len = mem_len
            # self.key_len = seq_length + mem_len
            self.clamp_len = clamp_len
            self.reuse_len = reuse_len
            self.is_training = is_training
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.cutoffs = cutoffs
            self.hidden_size = hidden_size
            self.num_attention_heads = num_attention_heads
            self.d_inner = d_inner
            self.num_hidden_layers = num_hidden_layers
            self.bi_data = bi_data
            self.untie_r = untie_r
            self.same_length = same_length
            self.initializer_range = initializer_range
            self.seed = seed
            self.type_vocab_size = type_vocab_size
            self.type_sequence_label_size = type_sequence_label_size

        def prepare_config_and_inputs(self):
            """Build an ``XLNetConfig`` and the tensors consumed by the ``create_and_check_*`` methods.

            Returns:
                An 11-tuple ``(config, input_ids_1, input_ids_2, input_ids_q,
                perm_mask, input_mask, target_mapping, segment_ids, lm_labels,
                sequence_labels, is_impossible_labels)``. The three label entries
                are ``None`` when ``self.use_labels`` is false.
            """
            batch_size = self.batch_size
            seq_length = self.seq_length

            input_ids_1 = ids_tensor([batch_size, seq_length], self.vocab_size)
            input_ids_2 = ids_tensor([batch_size, seq_length], self.vocab_size)
            segment_ids = ids_tensor([batch_size, seq_length], self.type_vocab_size)
            input_mask = ids_tensor([batch_size, seq_length], 2, dtype=tf.float32)

            # Inputs for the permutation-style call are one token longer.
            input_ids_q = ids_tensor([batch_size, seq_length + 1], self.vocab_size)

            # perm_mask has shape (batch, seq_length + 1, seq_length + 1): zeros
            # everywhere except the last column, which is 1, so that previous
            # tokens don't see the last token.
            perm_mask = tf.concat(
                [
                    tf.zeros((batch_size, seq_length + 1, seq_length), dtype=tf.float32),
                    tf.ones((batch_size, seq_length + 1, 1), dtype=tf.float32),
                ],
                axis=-1,
            )

            # target_mapping has shape (batch, 1, seq_length + 1) with a single 1
            # in the last position: predict the last token.
            target_mapping = tf.concat(
                [
                    tf.zeros((batch_size, 1, seq_length), dtype=tf.float32),
                    tf.ones((batch_size, 1, 1), dtype=tf.float32),
                ],
                axis=-1,
            )

            sequence_labels = None
            lm_labels = None
            is_impossible_labels = None
            if self.use_labels:
                lm_labels = ids_tensor([batch_size, seq_length], self.vocab_size)
                sequence_labels = ids_tensor([batch_size], self.type_sequence_label_size)
                is_impossible_labels = ids_tensor([batch_size], 2, dtype=tf.float32)

            config = XLNetConfig(
                vocab_size=self.vocab_size,
                d_model=self.hidden_size,
                n_head=self.num_attention_heads,
                d_inner=self.d_inner,
                n_layer=self.num_hidden_layers,
                untie_r=self.untie_r,
                mem_len=self.mem_len,
                clamp_len=self.clamp_len,
                same_length=self.same_length,
                reuse_len=self.reuse_len,
                bi_data=self.bi_data,
                initializer_range=self.initializer_range,
                num_labels=self.type_sequence_label_size,
            )

            return (
                config,
                input_ids_1,
                input_ids_2,
                input_ids_q,
                perm_mask,
                input_mask,
                target_mapping,
                segment_ids,
                lm_labels,
                sequence_labels,
                is_impossible_labels,
            )
thomwolf's avatar
thomwolf committed
156
157
158
159
160

        def set_seed(self):
            """Seed Python's ``random`` module and TensorFlow's RNG with ``self.seed``."""
            seed = self.seed
            random.seed(seed)
            tf.random.set_seed(seed)

161
162
163
164
165
166
167
168
169
170
171
172
173
174
        def create_and_check_xlnet_base_model(
            self,
            config,
            input_ids_1,
            input_ids_2,
            input_ids_q,
            perm_mask,
            input_mask,
            target_mapping,
            segment_ids,
            lm_labels,
            sequence_labels,
            is_impossible_labels,
        ):
            """Run ``TFXLNetModel`` on dict and list inputs and check the output shapes.

            Args:
                config: model configuration; its ``mem_len`` is set to 0 in place
                    for the final no-memory check.
                input_ids_1, input_mask, segment_ids: tensors fed to the model.

            The remaining parameters are unused here; they are accepted so the
            full tuple from ``prepare_config_and_inputs`` can be passed with ``*``.
            """
            model = TFXLNetModel(config)

            # Dict-style call: only checks that exactly two values come back.
            dict_inputs = {"input_ids": input_ids_1, "input_mask": input_mask, "token_type_ids": segment_ids}
            _, _ = model(dict_inputs)

            # List-style call: these outputs are the ones whose shapes are checked.
            inputs = [input_ids_1, input_mask]
            outputs, mems_1 = model(inputs)

            result = {
                "mems_1": [mem.numpy() for mem in mems_1],
                "outputs": outputs.numpy(),
            }

            # With mem_len set to 0 the model is expected to return one output only.
            config.mem_len = 0
            model = TFXLNetModel(config)
            no_mems_outputs = model(inputs)
            self.parent.assertEqual(len(no_mems_outputs), 1)

            self.parent.assertListEqual(
                list(result["outputs"].shape), [self.batch_size, self.seq_length, self.hidden_size]
            )
            self.parent.assertListEqual(
                [list(mem.shape) for mem in result["mems_1"]],
                [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
            )

        def create_and_check_xlnet_lm_head(
            self,
            config,
            input_ids_1,
            input_ids_2,
            input_ids_q,
            perm_mask,
            input_mask,
            target_mapping,
            segment_ids,
            lm_labels,
            sequence_labels,
            is_impossible_labels,
        ):
            """Run ``TFXLNetLMHeadModel`` three ways and check logits and memory shapes.

            Args:
                config: model configuration.
                input_ids_1, segment_ids: inputs of the first call.
                input_ids_2: inputs of the second call, which also receives the
                    memories returned by the first call.
                input_ids_q, perm_mask, target_mapping: inputs of the third call.

            The remaining parameters are unused here; they are accepted so the
            full tuple from ``prepare_config_and_inputs`` can be passed with ``*``.
            """
            model = TFXLNetLMHeadModel(config)

            inputs_1 = {"input_ids": input_ids_1, "token_type_ids": segment_ids}
            all_logits_1, mems_1 = model(inputs_1)

            # Second pass re-uses the memories produced by the first one.
            inputs_2 = {"input_ids": input_ids_2, "mems": mems_1, "token_type_ids": segment_ids}
            all_logits_2, mems_2 = model(inputs_2)

            # Third pass is a smoke test: it must run and return two values, but
            # its outputs are not inspected.
            inputs_3 = {"input_ids": input_ids_q, "perm_mask": perm_mask, "target_mapping": target_mapping}
            _, _ = model(inputs_3)

            result = {
                "mems_1": [mem.numpy() for mem in mems_1],
                "all_logits_1": all_logits_1.numpy(),
                "mems_2": [mem.numpy() for mem in mems_2],
                "all_logits_2": all_logits_2.numpy(),
            }

            logits_shape = [self.batch_size, self.seq_length, self.vocab_size]

            self.parent.assertListEqual(list(result["all_logits_1"].shape), logits_shape)
            self.parent.assertListEqual(
                [list(mem.shape) for mem in result["mems_1"]],
                [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
            )

            self.parent.assertListEqual(list(result["all_logits_2"].shape), logits_shape)
            # After the second pass the memories are expected to be mem_len long.
            self.parent.assertListEqual(
                [list(mem.shape) for mem in result["mems_2"]],
                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
            )

        def create_and_check_xlnet_qa(
            self,
            config,
            input_ids_1,
            input_ids_2,
            input_ids_q,
            perm_mask,
            input_mask,
            target_mapping,
            segment_ids,
            lm_labels,
            sequence_labels,
            is_impossible_labels,
        ):
            """Run ``TFXLNetForQuestionAnsweringSimple`` and check start/end logits and memory shapes.

            Args:
                config: model configuration.
                input_ids_1, input_mask, segment_ids: tensors fed to the model
                    (``input_mask`` is passed under the ``attention_mask`` key).

            The remaining parameters are unused here; they are accepted so the
            full tuple from ``prepare_config_and_inputs`` can be passed with ``*``.
            """
            model = TFXLNetForQuestionAnsweringSimple(config)

            inputs = {"input_ids": input_ids_1, "attention_mask": input_mask, "token_type_ids": segment_ids}
            start_logits, end_logits, mems = model(inputs)

            result = {
                "start_logits": start_logits.numpy(),
                "end_logits": end_logits.numpy(),
                "mems": [m.numpy() for m in mems],
            }

            span_logits_shape = [self.batch_size, self.seq_length]
            self.parent.assertListEqual(list(result["start_logits"].shape), span_logits_shape)
            self.parent.assertListEqual(list(result["end_logits"].shape), span_logits_shape)
            self.parent.assertListEqual(
                [list(mem.shape) for mem in result["mems"]],
                [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
            )

        def create_and_check_xlnet_sequence_classif(
            self,
            config,
            input_ids_1,
            input_ids_2,
            input_ids_q,
            perm_mask,
            input_mask,
            target_mapping,
            segment_ids,
            lm_labels,
            sequence_labels,
            is_impossible_labels,
        ):
            """Run ``TFXLNetForSequenceClassification`` on bare input ids and check output shapes.

            Args:
                config: model configuration.
                input_ids_1: the only tensor fed to the model.

            The remaining parameters are unused here; they are accepted so the
            full tuple from ``prepare_config_and_inputs`` can be passed with ``*``.
            """
            model = TFXLNetForSequenceClassification(config)

            logits, mems_1 = model(input_ids_1)

            result = {
                "mems_1": [mem.numpy() for mem in mems_1],
                "logits": logits.numpy(),
            }

            self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.type_sequence_label_size])
            self.parent.assertListEqual(
                [list(mem.shape) for mem in result["mems_1"]],
                [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
            )

        def create_and_check_xlnet_for_token_classification(
            self,
            config,
            input_ids_1,
            input_ids_2,
            input_ids_q,
            perm_mask,
            input_mask,
            target_mapping,
            segment_ids,
            lm_labels,
            sequence_labels,
            is_impossible_labels,
        ):
            """Run ``TFXLNetForTokenClassification`` and check logits and memory shapes.

            Args:
                config: model configuration; ``num_labels`` is overwritten in
                    place with the second dimension of ``input_ids_1``.
                input_ids_1, input_mask: tensors fed to the model
                    (``input_mask`` is passed under the ``attention_mask`` key).

            The remaining parameters are unused here; they are accepted so the
            full tuple from ``prepare_config_and_inputs`` can be passed with ``*``.
            """
            config.num_labels = input_ids_1.shape[1]
            model = TFXLNetForTokenClassification(config)

            # Token type ids are deliberately not passed to this model.
            inputs = {
                "input_ids": input_ids_1,
                "attention_mask": input_mask,
            }
            logits, mems_1 = model(inputs)

            result = {
                "mems_1": [mem.numpy() for mem in mems_1],
                "logits": logits.numpy(),
            }

            self.parent.assertListEqual(
                list(result["logits"].shape), [self.batch_size, self.seq_length, config.num_labels]
            )
            self.parent.assertListEqual(
                [list(mem.shape) for mem in result["mems_1"]],
                [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
            )
348

thomwolf's avatar
thomwolf committed
349
350
        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
351
352
353
354
355
356
357
358
359
360
361
362
363
364
            (
                config,
                input_ids_1,
                input_ids_2,
                input_ids_q,
                perm_mask,
                input_mask,
                target_mapping,
                segment_ids,
                lm_labels,
                sequence_labels,
                is_impossible_labels,
            ) = config_and_inputs
            inputs_dict = {"input_ids": input_ids_1}
thomwolf's avatar
thomwolf committed
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
            return config, inputs_dict

    def setUp(self):
        """Create the model tester and the config tester used by the tests of this class."""
        tester_cls = TFXLNetModelTest.TFXLNetModelTester
        self.model_tester = tester_cls(self)
        # NOTE(review): d_inner=37 is forwarded to ConfigTester as a keyword
        # argument; presumably a config override for the common checks -- confirm.
        self.config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)

    def test_config(self):
        """Run the common configuration checks provided by the config tester."""
        config_tester = self.config_tester
        config_tester.run_common_tests()

    def test_xlnet_base_model(self):
        """Seed the RNGs, build fresh inputs and run the base-model checks."""
        tester = self.model_tester
        tester.set_seed()
        prepared = tester.prepare_config_and_inputs()
        tester.create_and_check_xlnet_base_model(*prepared)

    def test_xlnet_lm_head(self):
        """Seed the RNGs, build fresh inputs and run the LM-head checks."""
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_lm_head(*config_and_inputs)
thomwolf's avatar
thomwolf committed
383
384
385
386
387
388

    def test_xlnet_sequence_classif(self):
        """Seed the RNGs, build fresh inputs and run the sequence-classification checks."""
        tester = self.model_tester
        tester.set_seed()
        prepared = tester.prepare_config_and_inputs()
        tester.create_and_check_xlnet_sequence_classif(*prepared)

389
390
391
392
    def test_xlnet_token_classification(self):
        """Seed the RNGs, build fresh inputs and run the token-classification checks."""
        # Seed first, like the other model tests in this class, so the generated
        # inputs do not depend on which tests ran before this one.
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_for_token_classification(*config_and_inputs)

thomwolf's avatar
thomwolf committed
393
394
395
396
397
    def test_xlnet_qa(self):
        """Seed the RNGs, build fresh inputs and run the question-answering checks."""
        tester = self.model_tester
        tester.set_seed()
        prepared = tester.prepare_config_and_inputs()
        tester.create_and_check_xlnet_qa(*prepared)

398
    @slow
    def test_model_from_pretrained(self):
        """Load the first model named in the archive map and check that a model object is returned."""
        # Only the first entry is exercised; the slice keeps the loop a no-op if
        # the archive map is empty.
        for model_name in list(TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)[:1]:
            model = TFXLNetModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
            self.assertIsNotNone(model)