# coding=utf-8
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import XLMConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask


if is_torch_available():
    import torch

    from transformers import (
        XLMForMultipleChoice,
        XLMForQuestionAnswering,
        XLMForQuestionAnsweringSimple,
        XLMForSequenceClassification,
        XLMForTokenClassification,
        XLMModel,
        XLMWithLMHeadModel,
    )
    from transformers.models.xlm.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_LIST


class XLMModelTester:
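    # Builds a deliberately tiny XLM config and matching random dummy inputs so that every
    # head exercised below runs quickly, even on CPU.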
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_lengths=True,
        use_token_type_ids=True,
        use_labels=True,
        gelu_activation=True,
        sinusoidal_embeddings=False,
        causal=False,
        asm=False,
        n_langs=2,
        vocab_size=99,
        n_special=0,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=2,
        num_choices=4,
        summary_type="last",
        use_proj=True,
        scope=None,
        bos_token_id=0,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_lengths = use_input_lengths
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.gelu_activation = gelu_activation
        self.sinusoidal_embeddings = sinusoidal_embeddings
        self.causal = causal
        self.asm = asm
        self.n_langs = n_langs
        self.vocab_size = vocab_size
        self.n_special = n_special
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.summary_type = summary_type
        self.use_proj = use_proj
        self.scope = scope
        self.bos_token_id = bos_token_id

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        input_mask = random_attention_mask([self.batch_size, self.seq_length])

        input_lengths = None
        if self.use_input_lengths:
            input_lengths = (
                ids_tensor([self.batch_size], vocab_size=2) + self.seq_length - 2
            )  # small variation of seq_length

        token_type_ids = None
        if self.use_token_type_ids:
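            # For XLM, token type ids double as language ids (they are passed as `langs`
            # below), which is why they are sampled from n_langs rather than a type vocab.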
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.n_langs)

        sequence_labels = None
        token_labels = None
        is_impossible_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            is_impossible_labels = ids_tensor([self.batch_size], 2).float()
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return (
            config,
            input_ids,
            token_type_ids,
            input_lengths,
            sequence_labels,
            token_labels,
            is_impossible_labels,
            choice_labels,
            input_mask,
        )

    def get_config(self):
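        # XLMConfig uses its own argument names (emb_dim, n_layers, n_heads, ...), so the
        # tester's BERT-style attributes are mapped onto them explicitly here.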
        return XLMConfig(
            vocab_size=self.vocab_size,
            n_special=self.n_special,
            emb_dim=self.hidden_size,
            n_layers=self.num_hidden_layers,
            n_heads=self.num_attention_heads,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            gelu_activation=self.gelu_activation,
            sinusoidal_embeddings=self.sinusoidal_embeddings,
            asm=self.asm,
            causal=self.causal,
            n_langs=self.n_langs,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
            summary_type=self.summary_type,
            use_proj=self.use_proj,
            num_labels=self.num_labels,
            bos_token_id=self.bos_token_id,
        )

    def create_and_check_xlm_model(
        self,
        config,
        input_ids,
        token_type_ids,
        input_lengths,
        sequence_labels,
        token_labels,
        is_impossible_labels,
        choice_labels,
        input_mask,
    ):
        model = XLMModel(config=config)
        model.to(torch_device)
        model.eval()
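        # The forward pass should succeed with or without the optional lengths/langs inputs.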
        result = model(input_ids, lengths=input_lengths, langs=token_type_ids)
        result = model(input_ids, langs=token_type_ids)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_xlm_lm_head(
        self,
        config,
        input_ids,
        token_type_ids,
        input_lengths,
        sequence_labels,
        token_labels,
        is_impossible_labels,
        choice_labels,
        input_mask,
    ):
        model = XLMWithLMHeadModel(config)
        model.to(torch_device)
        model.eval()

        result = model(input_ids, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_xlm_simple_qa(
        self,
        config,
        input_ids,
        token_type_ids,
        input_lengths,
        sequence_labels,
        token_labels,
        is_impossible_labels,
        choice_labels,
        input_mask,
    ):
        model = XLMForQuestionAnsweringSimple(config)
        model.to(torch_device)
        model.eval()

        result = model(input_ids)

        result = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

    def create_and_check_xlm_qa(
        self,
        config,
        input_ids,
        token_type_ids,
        input_lengths,
        sequence_labels,
        token_labels,
        is_impossible_labels,
        choice_labels,
        input_mask,
    ):
        model = XLMForQuestionAnswering(config)
        model.to(torch_device)
        model.eval()
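        # XLMForQuestionAnswering wraps the beam-search (SQuAD 2.0 style) QA head: besides
        # start/end positions it accepts cls_index, is_impossible and p_mask, and when no
        # labels are given it returns top-k start/end candidates plus answerability logits.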

        result = model(input_ids)

        result_with_labels = model(
            input_ids,
            start_positions=sequence_labels,
            end_positions=sequence_labels,
            cls_index=sequence_labels,
            is_impossible=is_impossible_labels,
            p_mask=input_mask,
        )

        result_with_labels = model(
            input_ids,
            start_positions=sequence_labels,
            end_positions=sequence_labels,
            cls_index=sequence_labels,
            is_impossible=is_impossible_labels,
        )

        (total_loss,) = result_with_labels.to_tuple()

        result_with_labels = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)

        (total_loss,) = result_with_labels.to_tuple()

        self.parent.assertEqual(result_with_labels.loss.shape, ())
        self.parent.assertEqual(result.start_top_log_probs.shape, (self.batch_size, model.config.start_n_top))
        self.parent.assertEqual(result.start_top_index.shape, (self.batch_size, model.config.start_n_top))
        self.parent.assertEqual(
            result.end_top_log_probs.shape, (self.batch_size, model.config.start_n_top * model.config.end_n_top)
        )
        self.parent.assertEqual(
            result.end_top_index.shape, (self.batch_size, model.config.start_n_top * model.config.end_n_top)
        )
        self.parent.assertEqual(result.cls_logits.shape, (self.batch_size,))

    def create_and_check_xlm_sequence_classif(
        self,
        config,
        input_ids,
        token_type_ids,
        input_lengths,
        sequence_labels,
        token_labels,
        is_impossible_labels,
        choice_labels,
        input_mask,
    ):
        model = XLMForSequenceClassification(config)
        model.to(torch_device)
        model.eval()

        result = model(input_ids)
        result = model(input_ids, labels=sequence_labels)
        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))

    def create_and_check_xlm_token_classif(
        self,
        config,
        input_ids,
        token_type_ids,
        input_lengths,
        sequence_labels,
        token_labels,
        is_impossible_labels,
        choice_labels,
        input_mask,
    ):
        config.num_labels = self.num_labels
        model = XLMForTokenClassification(config)
        model.to(torch_device)
        model.eval()

        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

    def create_and_check_xlm_for_multiple_choice(
        self,
        config,
        input_ids,
        token_type_ids,
        input_lengths,
        sequence_labels,
        token_labels,
        is_impossible_labels,
        choice_labels,
        input_mask,
    ):
        config.num_choices = self.num_choices
        model = XLMForMultipleChoice(config=config)
        model.to(torch_device)
        model.eval()
        multiple_choice_input_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
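        # Every input is tiled to (batch_size, num_choices, seq_length); the model scores one
        # logit per choice, so logits are expected to have shape (batch_size, num_choices).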
        result = model(
            multiple_choice_input_ids,
            attention_mask=multiple_choice_input_mask,
            token_type_ids=multiple_choice_token_type_ids,
            labels=choice_labels,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_lengths,
            sequence_labels,
            token_labels,
            is_impossible_labels,
            choice_labels,
            input_mask,
        ) = config_and_inputs
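        # XLM builds its attention mask from lengths when no mask is passed explicitly, so
        # the common tests only need input_ids, token_type_ids and lengths here.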
        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "lengths": input_lengths}
        return config, inputs_dict


@require_torch
class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):

    all_model_classes = (
        (
            XLMModel,
            XLMWithLMHeadModel,
            XLMForQuestionAnswering,
            XLMForSequenceClassification,
            XLMForQuestionAnsweringSimple,
            XLMForTokenClassification,
            XLMForMultipleChoice,
        )
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (
        (XLMWithLMHeadModel,) if is_torch_available() else ()
    )  # TODO (PVP): check whether language generation is also applicable to the other models

    # XLM has 2 QA models -> need to manually set the correct labels for one of them here
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

        if return_labels:
            if model_class.__name__ == "XLMForQuestionAnswering":
                inputs_dict["start_positions"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
                inputs_dict["end_positions"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )

        return inputs_dict

    def setUp(self):
        self.model_tester = XLMModelTester(self)
        self.config_tester = ConfigTester(self, config_class=XLMConfig, emb_dim=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_xlm_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_model(*config_and_inputs)

    def test_xlm_lm_head(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_lm_head(*config_and_inputs)

    def test_xlm_simple_qa(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_simple_qa(*config_and_inputs)

    def test_xlm_qa(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_qa(*config_and_inputs)

    def test_xlm_sequence_classif(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_sequence_classif(*config_and_inputs)

    def test_xlm_token_classif(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_token_classif(*config_and_inputs)

    def test_xlm_for_multiple_choice(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_for_multiple_choice(*config_and_inputs)

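    # XLM re-encodes the full sequence at every generation step (appending a dummy token,
    # hence the "+ 1" below), so the shape checks from the generation mixin are overridden
    # to expect attention and hidden-state sizes that grow with the current length.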
    def _check_attentions_for_generate(
        self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
    ):
        self.assertIsInstance(attentions, tuple)
        self.assertListEqual(
            [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
        )
        self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)

        for idx, iter_attentions in enumerate(attentions):
            # adds PAD dummy token
            tgt_len = min_length + idx + 1
            src_len = min_length + idx + 1

            expected_shape = (
                batch_size * num_beam_groups,
                config.num_attention_heads,
                tgt_len,
                src_len,
            )
            # check attn size
            self.assertListEqual(
                [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
            )

    def _check_hidden_states_for_generate(
        self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
    ):
        self.assertIsInstance(hidden_states, tuple)
        self.assertListEqual(
            [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
            [True] * len(hidden_states),
        )
        self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)

        for idx, iter_hidden_states in enumerate(hidden_states):
            # adds PAD dummy token
            seq_len = min_length + idx + 1
            expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
            # check hidden size
            self.assertListEqual(
                [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
                [expected_shape] * len(iter_hidden_states),
            )

    @slow
    def test_model_from_pretrained(self):
        for model_name in XLM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = XLMModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


@require_torch
class XLMModelLanguageGenerationTest(unittest.TestCase):
    @slow
    def test_lm_generate_xlm_mlm_en_2048(self):
        model = XLMWithLMHeadModel.from_pretrained("xlm-mlm-en-2048")
        model.to(torch_device)
        input_ids = torch.tensor([[14, 447]], dtype=torch.long, device=torch_device)  # the president
        expected_output_ids = [
            14,
            447,
            14,
            447,
            14,
            447,
            14,
            447,
            14,
            447,
            14,
            447,
            14,
            447,
            14,
            447,
            14,
            447,
            14,
            447,
        ]  # the president the president the president the president the president the president the president the president the president the president
        # TODO(PVP): this and other input_ids tried for generation give fairly poor results; unclear why. The model may simply not be built for auto-regressive inference
        output_ids = model.generate(input_ids, do_sample=False)
        self.assertListEqual(output_ids[0].cpu().numpy().tolist(), expected_output_ids)