# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import is_torch_available

from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_torch, slow, torch_device


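# Torch-dependent imports are guarded so the module can be collected even
# when PyTorch is not installed.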
if is_torch_available():
    from transformers import (
        XLMConfig,
        XLMModel,
        XLMWithLMHeadModel,
        XLMForQuestionAnswering,
        XLMForSequenceClassification,
        XLMForQuestionAnsweringSimple,
    )
    from transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP


@require_torch
class XLMModelTest(ModelTesterMixin, unittest.TestCase):

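    # Model classes exercised by the shared ModelTesterMixin tests; the tuple
    # is empty when PyTorch is unavailable so the class definition still loads.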
    all_model_classes = (
        (
            XLMModel,
            XLMWithLMHeadModel,
            XLMForQuestionAnswering,
            XLMForSequenceClassification,
            XLMForQuestionAnsweringSimple,
        )
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (
        (XLMWithLMHeadModel,) if is_torch_available() else ()
    )  # TODO (PVP): Check other models whether language generation is also applicable

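    # Builds a small random XLM config plus matching inputs/labels that the
    # create_and_check_* helpers below share.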
    class XLMModelTester(object):
        def __init__(
            self,
            parent,
            batch_size=13,
            seq_length=7,
            is_training=True,
            use_input_lengths=True,
            use_token_type_ids=True,
            use_labels=True,
            gelu_activation=True,
            sinusoidal_embeddings=False,
            causal=False,
            asm=False,
            n_langs=2,
            vocab_size=99,
            n_special=0,
            hidden_size=32,
            num_hidden_layers=5,
            num_attention_heads=4,
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=512,
            type_vocab_size=16,
            type_sequence_label_size=2,
            initializer_range=0.02,
            num_labels=3,
            num_choices=4,
            summary_type="last",
            use_proj=True,
            scope=None,
            bos_token_id=0,
        ):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.is_training = is_training
            self.use_input_lengths = use_input_lengths
            self.use_token_type_ids = use_token_type_ids
            self.use_labels = use_labels
            self.gelu_activation = gelu_activation
            self.sinusoidal_embeddings = sinusoidal_embeddings
            self.asm = asm
            self.n_langs = n_langs
            self.vocab_size = vocab_size
            self.n_special = n_special
            self.summary_type = summary_type
            self.causal = causal
            self.use_proj = use_proj
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_sequence_label_size = type_sequence_label_size
            self.initializer_range = initializer_range
            self.num_labels = num_labels
            self.num_choices = num_choices
            self.scope = scope
            self.bos_token_id = bos_token_id

        def prepare_config_and_inputs(self):
            input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
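            # Random 0/1 mask as floats; passed as p_mask in the QA check below.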
            input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float()

            input_lengths = None
            if self.use_input_lengths:
                input_lengths = (
                    ids_tensor([self.batch_size], vocab_size=2) + self.seq_length - 2
                )  # small variation of seq_length

            token_type_ids = None
            if self.use_token_type_ids:
                token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.n_langs)

            sequence_labels = None
            token_labels = None
            is_impossible_labels = None
            if self.use_labels:
                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
                token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
                is_impossible_labels = ids_tensor([self.batch_size], 2).float()

            config = XLMConfig(
                vocab_size=self.vocab_size,
                n_special=self.n_special,
                emb_dim=self.hidden_size,
                n_layers=self.num_hidden_layers,
                n_heads=self.num_attention_heads,
                dropout=self.hidden_dropout_prob,
                attention_dropout=self.attention_probs_dropout_prob,
                gelu_activation=self.gelu_activation,
                sinusoidal_embeddings=self.sinusoidal_embeddings,
                asm=self.asm,
                causal=self.causal,
                n_langs=self.n_langs,
                max_position_embeddings=self.max_position_embeddings,
                initializer_range=self.initializer_range,
                summary_type=self.summary_type,
                use_proj=self.use_proj,
                bos_token_id=self.bos_token_id,
            )

            return (
                config,
                input_ids,
                token_type_ids,
                input_lengths,
                sequence_labels,
                token_labels,
                is_impossible_labels,
                input_mask,
            )

        def check_loss_output(self, result):
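            # A well-formed loss is a 0-dim tensor, i.e. an empty size list.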
            self.parent.assertListEqual(list(result["loss"].size()), [])

        def create_and_check_xlm_model(
            self,
            config,
            input_ids,
            token_type_ids,
            input_lengths,
            sequence_labels,
            token_labels,
            is_impossible_labels,
            input_mask,
        ):
            model = XLMModel(config=config)
            model.to(torch_device)
            model.eval()
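            # Exercise the forward pass with and without the optional
            # lengths/langs arguments.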
            outputs = model(input_ids, lengths=input_lengths, langs=token_type_ids)
            outputs = model(input_ids, langs=token_type_ids)
            outputs = model(input_ids)
            sequence_output = outputs[0]
            result = {
                "sequence_output": sequence_output,
            }
            self.parent.assertListEqual(
                list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
            )

        def create_and_check_xlm_lm_head(
            self,
            config,
            input_ids,
            token_type_ids,
            input_lengths,
            sequence_labels,
            token_labels,
            is_impossible_labels,
            input_mask,
        ):
            model = XLMWithLMHeadModel(config)
            model.to(torch_device)
            model.eval()

            loss, logits = model(input_ids, token_type_ids=token_type_ids, labels=token_labels)

            result = {
                "loss": loss,
                "logits": logits,
            }

            self.parent.assertListEqual(list(result["loss"].size()), [])
            self.parent.assertListEqual(
                list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
            )

        def create_and_check_xlm_simple_qa(
            self,
            config,
            input_ids,
            token_type_ids,
            input_lengths,
            sequence_labels,
            token_labels,
            is_impossible_labels,
            input_mask,
        ):
            model = XLMForQuestionAnsweringSimple(config)
            model.to(torch_device)
            model.eval()

            outputs = model(input_ids)

            outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
            loss, start_logits, end_logits = outputs

            result = {
                "loss": loss,
                "start_logits": start_logits,
                "end_logits": end_logits,
            }
            self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
            self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
            self.check_loss_output(result)

        def create_and_check_xlm_qa(
            self,
            config,
            input_ids,
            token_type_ids,
            input_lengths,
            sequence_labels,
            token_labels,
            is_impossible_labels,
            input_mask,
        ):
            model = XLMForQuestionAnswering(config)
            model.to(torch_device)
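            # Without labels the QA head returns beam-search style outputs:
            # top start/end candidates plus an answerability (cls) logit.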
            model.eval()

            outputs = model(input_ids)
            start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = outputs

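            # With supervision only the total loss is returned; run with and
            # without the optional is_impossible/p_mask arguments.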
            outputs = model(
                input_ids,
                start_positions=sequence_labels,
                end_positions=sequence_labels,
                cls_index=sequence_labels,
                is_impossible=is_impossible_labels,
                p_mask=input_mask,
            )

            outputs = model(
                input_ids,
                start_positions=sequence_labels,
                end_positions=sequence_labels,
                cls_index=sequence_labels,
                is_impossible=is_impossible_labels,
            )

            (total_loss,) = outputs

            outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)

            (total_loss,) = outputs

            result = {
                "loss": total_loss,
                "start_top_log_probs": start_top_log_probs,
                "start_top_index": start_top_index,
                "end_top_log_probs": end_top_log_probs,
                "end_top_index": end_top_index,
                "cls_logits": cls_logits,
            }

            self.parent.assertListEqual(list(result["loss"].size()), [])
            self.parent.assertListEqual(
                list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top]
            )
            self.parent.assertListEqual(
                list(result["start_top_index"].size()), [self.batch_size, model.config.start_n_top]
            )
            self.parent.assertListEqual(
                list(result["end_top_log_probs"].size()),
                [self.batch_size, model.config.start_n_top * model.config.end_n_top],
            )
            self.parent.assertListEqual(
                list(result["end_top_index"].size()),
                [self.batch_size, model.config.start_n_top * model.config.end_n_top],
            )
            self.parent.assertListEqual(list(result["cls_logits"].size()), [self.batch_size])

        def create_and_check_xlm_sequence_classif(
            self,
            config,
            input_ids,
            token_type_ids,
            input_lengths,
            sequence_labels,
            token_labels,
            is_impossible_labels,
            input_mask,
        ):
            model = XLMForSequenceClassification(config)
            model.to(torch_device)
            model.eval()

            (logits,) = model(input_ids)
            loss, logits = model(input_ids, labels=sequence_labels)

            result = {
                "loss": loss,
                "logits": logits,
            }

            self.parent.assertListEqual(list(result["loss"].size()), [])
            self.parent.assertListEqual(
                list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size]
            )

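        # Repackage the fixture as (config, inputs_dict) for the shared
        # ModelTesterMixin tests.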
        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
            (
                config,
                input_ids,
                token_type_ids,
                input_lengths,
                sequence_labels,
                token_labels,
                is_impossible_labels,
                input_mask,
            ) = config_and_inputs
            inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "lengths": input_lengths}
            return config, inputs_dict

    def setUp(self):
        self.model_tester = XLMModelTest.XLMModelTester(self)
        self.config_tester = ConfigTester(self, config_class=XLMConfig, emb_dim=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_xlm_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_model(*config_and_inputs)

    def test_xlm_lm_head(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_lm_head(*config_and_inputs)

    def test_xlm_simple_qa(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_simple_qa(*config_and_inputs)

    def test_xlm_qa(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_qa(*config_and_inputs)

    def test_xlm_sequence_classif(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_sequence_classif(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
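        # Checking a single checkpoint is enough to cover the download path.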
        for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = XLMModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
            self.assertIsNotNone(model)