# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

from transformers import is_torch_available

from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device


if is_torch_available():
    import torch

    from transformers import (
        XLMConfig,
        XLMModel,
        XLMWithLMHeadModel,
        XLMForTokenClassification,
        XLMForQuestionAnswering,
        XLMForSequenceClassification,
        XLMForQuestionAnsweringSimple,
    )
    from transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_LIST


@require_torch
class XLMModelTest(ModelTesterMixin, unittest.TestCase):

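    # ModelTesterMixin picks up this tuple and runs its battery of common tests
    # (forward pass, attention/hidden-state outputs, save/load round-trips, ...)
    # against every listed XLM head.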
    all_model_classes = (
        (
            XLMModel,
            XLMWithLMHeadModel,
            XLMForQuestionAnswering,
            XLMForSequenceClassification,
            XLMForQuestionAnsweringSimple,
        )
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (
        (XLMWithLMHeadModel,) if is_torch_available() else ()
    )  # TODO (PVP): check whether language generation is also applicable to other models

    class XLMModelTester(object):
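        # Builds a deliberately tiny random XLM config plus matching dummy inputs
        # so that every head can be exercised quickly on CPU.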
        def __init__(
            self,
            parent,
            batch_size=13,
            seq_length=7,
            is_training=True,
            use_input_lengths=True,
            use_token_type_ids=True,
            use_labels=True,
            gelu_activation=True,
            sinusoidal_embeddings=False,
            causal=False,
            asm=False,
            n_langs=2,
            vocab_size=99,
            n_special=0,
            hidden_size=32,
            num_hidden_layers=5,
            num_attention_heads=4,
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=512,
            type_vocab_size=16,
            type_sequence_label_size=2,
            initializer_range=0.02,
            num_labels=3,
            num_choices=4,
            summary_type="last",
            use_proj=True,
            scope=None,
            bos_token_id=0,
        ):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.is_training = is_training
            self.use_input_lengths = use_input_lengths
            self.use_token_type_ids = use_token_type_ids
            self.use_labels = use_labels
            self.gelu_activation = gelu_activation
            self.sinusoidal_embeddings = sinusoidal_embeddings
            self.asm = asm
            self.n_langs = n_langs
            self.vocab_size = vocab_size
            self.n_special = n_special
            self.summary_type = summary_type
            self.causal = causal
            self.use_proj = use_proj
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_sequence_label_size = type_sequence_label_size
            self.initializer_range = initializer_range
            self.num_labels = num_labels
            self.num_choices = num_choices
            self.scope = scope
            self.bos_token_id = bos_token_id

        def prepare_config_and_inputs(self):
            input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
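            # random 0/1 float mask; reused further down as p_mask for the QA head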
            input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float()

            input_lengths = None
            if self.use_input_lengths:
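                # lengths land in {seq_length - 2, seq_length - 1}, exercising padding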
                input_lengths = (
                    ids_tensor([self.batch_size], vocab_size=2) + self.seq_length - 2
                )  # small variation of seq_length

            token_type_ids = None
            if self.use_token_type_ids:
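                # for XLM these double as language IDs (passed as langs= below),
                # hence values in [0, n_langs)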
                token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.n_langs)

            sequence_labels = None
            token_labels = None
            is_impossible_labels = None
            if self.use_labels:
                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
                token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
                is_impossible_labels = ids_tensor([self.batch_size], 2).float()

            config = XLMConfig(
                vocab_size=self.vocab_size,
                n_special=self.n_special,
                emb_dim=self.hidden_size,
                n_layers=self.num_hidden_layers,
                n_heads=self.num_attention_heads,
                dropout=self.hidden_dropout_prob,
                attention_dropout=self.attention_probs_dropout_prob,
                gelu_activation=self.gelu_activation,
                sinusoidal_embeddings=self.sinusoidal_embeddings,
                asm=self.asm,
                causal=self.causal,
                n_langs=self.n_langs,
                max_position_embeddings=self.max_position_embeddings,
                initializer_range=self.initializer_range,
                summary_type=self.summary_type,
                use_proj=self.use_proj,
                bos_token_id=self.bos_token_id,
            )

            return (
                config,
                input_ids,
                token_type_ids,
                input_lengths,
                sequence_labels,
                token_labels,
                is_impossible_labels,
                input_mask,
            )

        def check_loss_output(self, result):
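            # a scalar loss tensor has an empty size, hence the comparison with []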
            self.parent.assertListEqual(list(result["loss"].size()), [])

        def create_and_check_xlm_model(
            self,
            config,
            input_ids,
            token_type_ids,
            input_lengths,
            sequence_labels,
            token_labels,
            is_impossible_labels,
            input_mask,
        ):
            model = XLMModel(config=config)
            model.to(torch_device)
            model.eval()
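            # smoke-test the forward signature with progressively fewer optional
            # inputs; only the last (minimal) call's output is shape-checked below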
            outputs = model(input_ids, lengths=input_lengths, langs=token_type_ids)
            outputs = model(input_ids, langs=token_type_ids)
            outputs = model(input_ids)
            sequence_output = outputs[0]
            result = {
                "sequence_output": sequence_output,
            }
            self.parent.assertListEqual(
                list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
            )

        def create_and_check_xlm_lm_head(
            self,
            config,
            input_ids,
            token_type_ids,
            input_lengths,
            sequence_labels,
            token_labels,
            is_impossible_labels,
            input_mask,
        ):
            model = XLMWithLMHeadModel(config)
            model.to(torch_device)
            model.eval()

            loss, logits = model(input_ids, token_type_ids=token_type_ids, labels=token_labels)

            result = {
                "loss": loss,
                "logits": logits,
            }

            self.parent.assertListEqual(list(result["loss"].size()), [])
            self.parent.assertListEqual(
                list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
            )

        def create_and_check_xlm_simple_qa(
            self,
            config,
            input_ids,
            token_type_ids,
            input_lengths,
            sequence_labels,
            token_labels,
            is_impossible_labels,
            input_mask,
        ):
            model = XLMForQuestionAnsweringSimple(config)
            model.to(torch_device)
            model.eval()

            outputs = model(input_ids)

            outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
            loss, start_logits, end_logits = outputs

            result = {
                "loss": loss,
                "start_logits": start_logits,
                "end_logits": end_logits,
            }
            self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
            self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
            self.check_loss_output(result)

        def create_and_check_xlm_qa(
            self,
            config,
            input_ids,
            token_type_ids,
            input_lengths,
            sequence_labels,
            token_labels,
            is_impossible_labels,
            input_mask,
        ):
            model = XLMForQuestionAnswering(config)
            model.to(torch_device)
            model.eval()

            outputs = model(input_ids)
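            # without labels the SQuAD-style head returns beam-search tensors
            # (top start/end log-probs and indices) plus cls_logits, which feeds
            # the is-impossible prediction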
            start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = outputs

            outputs = model(
                input_ids,
                start_positions=sequence_labels,
                end_positions=sequence_labels,
                cls_index=sequence_labels,
                is_impossible=is_impossible_labels,
                p_mask=input_mask,
            )

            outputs = model(
                input_ids,
                start_positions=sequence_labels,
                end_positions=sequence_labels,
                cls_index=sequence_labels,
                is_impossible=is_impossible_labels,
            )

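            # once start/end positions are supplied, only the total loss comes back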
            (total_loss,) = outputs

            outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)

            (total_loss,) = outputs

            result = {
                "loss": total_loss,
                "start_top_log_probs": start_top_log_probs,
                "start_top_index": start_top_index,
                "end_top_log_probs": end_top_log_probs,
                "end_top_index": end_top_index,
                "cls_logits": cls_logits,
            }

            self.parent.assertListEqual(list(result["loss"].size()), [])
            self.parent.assertListEqual(
                list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top]
            )
            self.parent.assertListEqual(
                list(result["start_top_index"].size()), [self.batch_size, model.config.start_n_top]
            )
            self.parent.assertListEqual(
                list(result["end_top_log_probs"].size()),
                [self.batch_size, model.config.start_n_top * model.config.end_n_top],
            )
            self.parent.assertListEqual(
                list(result["end_top_index"].size()),
                [self.batch_size, model.config.start_n_top * model.config.end_n_top],
            )
            self.parent.assertListEqual(list(result["cls_logits"].size()), [self.batch_size])

        def create_and_check_xlm_sequence_classif(
            self,
            config,
            input_ids,
            token_type_ids,
            input_lengths,
            sequence_labels,
            token_labels,
            is_impossible_labels,
            input_mask,
        ):
            model = XLMForSequenceClassification(config)
            model.to(torch_device)
            model.eval()

            (logits,) = model(input_ids)
            loss, logits = model(input_ids, labels=sequence_labels)

            result = {
                "loss": loss,
                "logits": logits,
            }

            self.parent.assertListEqual(list(result["loss"].size()), [])
            self.parent.assertListEqual(
                list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size]
            )

        def create_and_check_xlm_for_token_classification(
            self,
            config,
            input_ids,
            token_type_ids,
            input_lengths,
            sequence_labels,
            token_labels,
            is_impossible_labels,
            input_mask,
        ):
            config.num_labels = self.num_labels
            model = XLMForTokenClassification(config)
            model.to(torch_device)
            model.eval()

            loss, logits = model(input_ids, attention_mask=input_mask, labels=token_labels)
            result = {
                "loss": loss,
                "logits": logits,
            }
            self.parent.assertListEqual(
                list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]
            )
            self.check_loss_output(result)

        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
            (
                config,
                input_ids,
                token_type_ids,
                input_lengths,
                sequence_labels,
                token_labels,
                is_impossible_labels,
                input_mask,
            ) = config_and_inputs
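            # the common tests feed this dict to every class in all_model_classes,
            # so it only carries inputs that all XLM heads accept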
            inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "lengths": input_lengths}
            return config, inputs_dict

    def setUp(self):
        self.model_tester = XLMModelTest.XLMModelTester(self)
        self.config_tester = ConfigTester(self, config_class=XLMConfig, emb_dim=37)
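        # emb_dim is XLM's name for the hidden size; 37 is just an arbitrary
        # small value for the common config tests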

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_xlm_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_model(*config_and_inputs)

    def test_xlm_lm_head(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_lm_head(*config_and_inputs)

    def test_xlm_simple_qa(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_simple_qa(*config_and_inputs)

    def test_xlm_qa(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_qa(*config_and_inputs)

    def test_xlm_sequence_classif(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_sequence_classif(*config_and_inputs)

    def test_xlm_for_token_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_for_token_classification(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in XLM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = XLMModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


@require_torch
class XLMModelLanguageGenerationTest(unittest.TestCase):
    @slow
    def test_lm_generate_xlm_mlm_en_2048(self):
        model = XLMWithLMHeadModel.from_pretrained("xlm-mlm-en-2048")
        model.to(torch_device)
        input_ids = torch.tensor([[14, 447]], dtype=torch.long, device=torch_device)  # the president
        expected_output_ids = [
            14,
            447,
            14,
            447,
            14,
            447,
            14,
            447,
            14,
            447,
            14,
            447,
            14,
            447,
            14,
            447,
            14,
            447,
            14,
            447,
        ]  # the president the president the president the president the president the president the president the president the president the president
        # TODO(PVP): this and other input_ids I tried for generation give pretty bad results. Not sure why. Model might just not be made for auto-regressive inference
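        # do_sample=False means greedy decoding, so the output is deterministic
        # and can be compared against the hard-coded ids above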
        output_ids = model.generate(input_ids, do_sample=False)
        self.assertListEqual(output_ids[0].cpu().numpy().tolist(), expected_output_ids)