# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function

import unittest

from transformers import is_torch_available

from .test_configuration_common import ConfigTester
from .test_modeling_common import CommonTestCases, ids_tensor
from .utils import CACHE_DIR, require_torch, slow, torch_device


if is_torch_available():
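    # Torch-dependent imports are guarded so this module can be collected without PyTorch installed.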
    from transformers import (
        XLMConfig,
        XLMModel,
        XLMWithLMHeadModel,
        XLMForQuestionAnswering,
        XLMForSequenceClassification,
        XLMForQuestionAnsweringSimple,
    )
    from transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP


@require_torch
class XLMModelTest(CommonTestCases.CommonModelTester):
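    """Wires the XLM model classes into the shared CommonModelTester harness."""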

    all_model_classes = (
        (
            XLMModel,
            XLMWithLMHeadModel,
            XLMForQuestionAnswering,
            XLMForSequenceClassification,
            XLMForQuestionAnsweringSimple,
        )
        if is_torch_available()
        else ()
    )

    class XLMModelTester(object):
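        """Builds a small random XLMConfig and matching inputs for every head-specific check."""
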
        def __init__(
            self,
            parent,
            batch_size=13,
            seq_length=7,
            is_training=True,
            use_input_lengths=True,
            use_token_type_ids=True,
            use_labels=True,
            gelu_activation=True,
            sinusoidal_embeddings=False,
            causal=False,
            asm=False,
            n_langs=2,
            vocab_size=99,
            n_special=0,
            hidden_size=32,
            num_hidden_layers=5,
            num_attention_heads=4,
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=512,
            type_vocab_size=16,
            type_sequence_label_size=2,
            initializer_range=0.02,
            num_labels=3,
            num_choices=4,
            summary_type="last",
            use_proj=True,
            scope=None,
        ):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.is_training = is_training
            self.use_input_lengths = use_input_lengths
            self.use_token_type_ids = use_token_type_ids
            self.use_labels = use_labels
            self.gelu_activation = gelu_activation
            self.sinusoidal_embeddings = sinusoidal_embeddings
            self.asm = asm
            self.n_langs = n_langs
            self.vocab_size = vocab_size
            self.n_special = n_special
            self.summary_type = summary_type
            self.causal = causal
            self.use_proj = use_proj
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_sequence_label_size = type_sequence_label_size
            self.initializer_range = initializer_range
            self.num_labels = num_labels
            self.num_choices = num_choices
            self.scope = scope

        def prepare_config_and_inputs(self):
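            # Random token ids in [0, vocab_size) and a float 0/1 mask, reused later as p_mask.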
            input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float()

            input_lengths = None
            if self.use_input_lengths:
                input_lengths = (
                    ids_tensor([self.batch_size], vocab_size=2) + self.seq_length - 2
                )  # small variation of seq_length

            token_type_ids = None
            if self.use_token_type_ids:
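                # XLM reuses token_type_ids as language ids, hence values in [0, n_langs).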
                token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.n_langs)

            sequence_labels = None
            token_labels = None
            is_impossible_labels = None
            if self.use_labels:
                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
                token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
                is_impossible_labels = ids_tensor([self.batch_size], 2).float()

            config = XLMConfig(
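                # Config kwargs follow XLM's original naming: emb_dim, n_layers, n_heads, etc.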
                vocab_size=self.vocab_size,
                n_special=self.n_special,
                emb_dim=self.hidden_size,
                n_layers=self.num_hidden_layers,
                n_heads=self.num_attention_heads,
                dropout=self.hidden_dropout_prob,
                attention_dropout=self.attention_probs_dropout_prob,
                gelu_activation=self.gelu_activation,
                sinusoidal_embeddings=self.sinusoidal_embeddings,
                asm=self.asm,
                causal=self.causal,
                n_langs=self.n_langs,
                max_position_embeddings=self.max_position_embeddings,
                initializer_range=self.initializer_range,
                summary_type=self.summary_type,
                use_proj=self.use_proj,
            )

            return (
                config,
                input_ids,
                token_type_ids,
                input_lengths,
                sequence_labels,
                token_labels,
                is_impossible_labels,
                input_mask,
            )

        def check_loss_output(self, result):
            self.parent.assertListEqual(list(result["loss"].size()), [])

        def create_and_check_xlm_model(
            self,
            config,
            input_ids,
            token_type_ids,
            input_lengths,
            sequence_labels,
            token_labels,
            is_impossible_labels,
            input_mask,
        ):
            model = XLMModel(config=config)
            model.to(torch_device)
            model.eval()
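            # The forward pass should work with lengths and langs, with langs only, or with just input_ids.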
            outputs = model(input_ids, lengths=input_lengths, langs=token_type_ids)
            outputs = model(input_ids, langs=token_type_ids)
            outputs = model(input_ids)
            sequence_output = outputs[0]
            result = {
                "sequence_output": sequence_output,
            }
            self.parent.assertListEqual(
                list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
            )

        def create_and_check_xlm_lm_head(
            self,
            config,
            input_ids,
            token_type_ids,
            input_lengths,
            sequence_labels,
            token_labels,
            is_impossible_labels,
            input_mask,
        ):
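            # With labels, the LM head returns a scalar loss plus per-token vocabulary logits.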
            model = XLMWithLMHeadModel(config)
            model.to(torch_device)
            model.eval()

            loss, logits = model(input_ids, token_type_ids=token_type_ids, labels=token_labels)

            result = {
                "loss": loss,
                "logits": logits,
            }

            self.parent.assertListEqual(list(result["loss"].size()), [])
            self.parent.assertListEqual(
                list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
            )

        def create_and_check_xlm_simple_qa(
            self,
            config,
            input_ids,
            token_type_ids,
            input_lengths,
            sequence_labels,
            token_labels,
            is_impossible_labels,
            input_mask,
        ):
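            # The simple QA head returns one start and one end logit per token.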
            model = XLMForQuestionAnsweringSimple(config)
            model.to(torch_device)
            model.eval()

            outputs = model(input_ids)

            outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
            loss, start_logits, end_logits = outputs

            result = {
                "loss": loss,
                "start_logits": start_logits,
                "end_logits": end_logits,
            }
            self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
            self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
            self.check_loss_output(result)

        def create_and_check_xlm_qa(
            self,
            config,
            input_ids,
            token_type_ids,
            input_lengths,
            sequence_labels,
            token_labels,
            is_impossible_labels,
            input_mask,
        ):
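            # An unlabeled forward yields the beam-search QA outputs; labeled forwards yield only the total loss.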
            model = XLMForQuestionAnswering(config)
            model.to(torch_device)
            model.eval()

            outputs = model(input_ids)
            start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = outputs

            outputs = model(
                input_ids,
                start_positions=sequence_labels,
                end_positions=sequence_labels,
                cls_index=sequence_labels,
                is_impossible=is_impossible_labels,
                p_mask=input_mask,
            )

            outputs = model(
                input_ids,
                start_positions=sequence_labels,
                end_positions=sequence_labels,
                cls_index=sequence_labels,
                is_impossible=is_impossible_labels,
            )

            (total_loss,) = outputs

            outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)

            (total_loss,) = outputs

            result = {
                "loss": total_loss,
                "start_top_log_probs": start_top_log_probs,
                "start_top_index": start_top_index,
                "end_top_log_probs": end_top_log_probs,
                "end_top_index": end_top_index,
                "cls_logits": cls_logits,
            }

            self.parent.assertListEqual(list(result["loss"].size()), [])
            self.parent.assertListEqual(
                list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top]
            )
            self.parent.assertListEqual(
                list(result["start_top_index"].size()), [self.batch_size, model.config.start_n_top]
            )
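            # The end_* tensors are flattened over the start_n_top x end_n_top candidate pairs.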
            self.parent.assertListEqual(
                list(result["end_top_log_probs"].size()),
                [self.batch_size, model.config.start_n_top * model.config.end_n_top],
            )
            self.parent.assertListEqual(
                list(result["end_top_index"].size()),
                [self.batch_size, model.config.start_n_top * model.config.end_n_top],
            )
            self.parent.assertListEqual(list(result["cls_logits"].size()), [self.batch_size])

        def create_and_check_xlm_sequence_classif(
            self,
            config,
            input_ids,
            token_type_ids,
            input_lengths,
            sequence_labels,
            token_labels,
            is_impossible_labels,
            input_mask,
        ):
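            # Logits are checked against type_sequence_label_size, matching the config's default label count here.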
            model = XLMForSequenceClassification(config)
            model.to(torch_device)
            model.eval()

            (logits,) = model(input_ids)
            loss, logits = model(input_ids, labels=sequence_labels)

            result = {
                "loss": loss,
                "logits": logits,
            }

            self.parent.assertListEqual(list(result["loss"].size()), [])
            self.parent.assertListEqual(
                list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size]
            )

        def prepare_config_and_inputs_for_common(self):
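            # Repack the tuple into the kwargs dict consumed by the shared CommonModelTester tests.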
            config_and_inputs = self.prepare_config_and_inputs()
            (
                config,
                input_ids,
                token_type_ids,
                input_lengths,
                sequence_labels,
                token_labels,
                is_impossible_labels,
                input_mask,
            ) = config_and_inputs
            inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "lengths": input_lengths}
            return config, inputs_dict

    def setUp(self):
        self.model_tester = XLMModelTest.XLMModelTester(self)
        self.config_tester = ConfigTester(self, config_class=XLMConfig, emb_dim=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_xlm_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_model(*config_and_inputs)

    def test_xlm_lm_head(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_lm_head(*config_and_inputs)

    def test_xlm_simple_qa(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_simple_qa(*config_and_inputs)

    def test_xlm_qa(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_qa(*config_and_inputs)

    def test_xlm_sequence_classif(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_sequence_classif(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
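            # Only the first checkpoint is exercised to keep this slow test manageable.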
            model = XLMModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
            self.assertIsNotNone(model)


if __name__ == "__main__":
    unittest.main()