test_modeling_bert.py 17.5 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Aymeric Augustin's avatar
Aymeric Augustin committed
15

thomwolf's avatar
thomwolf committed
16

17
18
import unittest

19
from transformers import is_torch_available
20
from transformers.testing_utils import require_torch, slow, torch_device
thomwolf's avatar
thomwolf committed
21

22
from .test_configuration_common import ConfigTester
23
from .test_generation_utils import GenerationTesterMixin
24
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
thomwolf's avatar
thomwolf committed
25

Aymeric Augustin's avatar
Aymeric Augustin committed
26

27
if is_torch_available():
28
29
30
    from transformers import (
        BertConfig,
        BertForMaskedLM,
31
        BertForMultipleChoice,
32
33
34
35
36
        BertForNextSentencePrediction,
        BertForPreTraining,
        BertForQuestionAnswering,
        BertForSequenceClassification,
        BertForTokenClassification,
37
38
        BertLMHeadModel,
        BertModel,
39
    )
40
    from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
thomwolf's avatar
thomwolf committed
41

thomwolf's avatar
thomwolf committed
42

43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
class BertModelTester:
    """Build a tiny random BERT setup and shape-check every BERT model head.

    ``parent`` is the owning ``unittest.TestCase``; every check is reported
    through ``parent.assertEqual`` so failures surface in that test case. The
    remaining constructor arguments size the toy ``BertConfig`` and the random
    inputs/labels produced by the ``prepare_*`` methods.
    """

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope

    def prepare_config_and_inputs(self):
        """Create a ``BertConfig`` plus random encoder inputs and labels.

        Returns:
            ``(config, input_ids, token_type_ids, input_mask, sequence_labels,
            token_labels, choice_labels)``. ``input_mask`` is ``None`` unless
            ``use_input_mask`` is set, ``token_type_ids`` is ``None`` unless
            ``use_token_type_ids`` is set, and all three label tensors are
            ``None`` unless ``use_labels`` is set.
        """
        # NOTE(review): ids_tensor(shape, n) is assumed to return random integer
        # ids of the given shape with values below ``n`` — confirm in
        # test_modeling_common.
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = BertConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
            # The checks below read outputs by attribute name (``result.logits``).
            return_dict=True,
        )

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def prepare_config_and_inputs_for_decoder(self):
        """Extend ``prepare_config_and_inputs`` with decoder settings and cross-attention inputs.

        Sets ``config.is_decoder = True`` and appends ``encoder_hidden_states``
        (built with ``floats_tensor([batch_size, seq_length, hidden_size])``) and
        ``encoder_attention_mask`` (built with ``ids_tensor(..., vocab_size=2)``,
        presumably a 0/1 mask) to the 7-tuple, returning 9 values in total.
        """
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = self.prepare_config_and_inputs()

        config.is_decoder = True
        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def _check_base_model_output(self, result):
        """Assert the ``BertModel`` output shapes for the configured batch/sequence/hidden sizes."""
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def _expand_for_multiple_choice(self, tensor):
        """Tile a ``(batch, seq)`` tensor to ``(batch, num_choices, seq)``; pass ``None`` through.

        ``input_mask`` and ``token_type_ids`` are ``None`` when the matching
        ``use_*`` flag is disabled, and must then stay ``None`` instead of being
        expanded (which would raise ``AttributeError``).
        """
        if tensor is None:
            return None
        return tensor.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run ``BertModel`` with progressively fewer optional inputs and shape-check every output."""
        model = BertModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        self._check_base_model_output(result)
        result = model(input_ids, token_type_ids=token_type_ids)
        self._check_base_model_output(result)
        result = model(input_ids)
        self._check_base_model_output(result)

    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Run ``BertModel`` as a cross-attending decoder and shape-check every output.

        Mutates the passed ``config`` by setting ``add_cross_attention = True``.
        The model is called with full encoder inputs, then without
        ``encoder_attention_mask``, then without any encoder inputs.
        """
        config.add_cross_attention = True
        model = BertModel(config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        self._check_base_model_output(result)
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            encoder_hidden_states=encoder_hidden_states,
        )
        self._check_base_model_output(result)
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        self._check_base_model_output(result)

    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Check ``BertLMHeadModel`` logits shape without encoder inputs.

        Takes the 9-value decoder tuple so it can be called with
        ``*prepare_config_and_inputs_for_decoder()``; the encoder tensors are unused.
        """
        model = BertLMHeadModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_for_masked_lm(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check ``BertForMaskedLM`` logits shape ``(batch, seq, vocab)``."""
        model = BertForMaskedLM(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_model_for_causal_lm_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Check ``BertLMHeadModel`` with cross-attention, with and without ``encoder_attention_mask``.

        Mutates the passed ``config`` by setting ``add_cross_attention = True``.
        """
        config.add_cross_attention = True
        model = BertLMHeadModel(config=config)
        model.to(torch_device)
        model.eval()
        expected_logits_shape = (self.batch_size, self.seq_length, self.vocab_size)
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=token_labels,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        self.parent.assertEqual(result.logits.shape, expected_logits_shape)
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=token_labels,
            encoder_hidden_states=encoder_hidden_states,
        )
        self.parent.assertEqual(result.logits.shape, expected_logits_shape)

    def create_and_check_for_next_sequence_prediction(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check ``BertForNextSentencePrediction`` logits shape ``(batch, 2)``.

        (The method name says "sequence"; it is kept for backward compatibility.)
        """
        model = BertForNextSentencePrediction(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            next_sentence_label=sequence_labels,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, 2))

    def create_and_check_for_pretraining(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check both ``BertForPreTraining`` heads: token prediction logits and the 2-way relationship logits."""
        model = BertForPreTraining(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=token_labels,
            next_sentence_label=sequence_labels,
        )
        self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
        self.parent.assertEqual(result.seq_relationship_logits.shape, (self.batch_size, 2))

    def create_and_check_for_question_answering(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check ``BertForQuestionAnswering`` start/end logits shapes ``(batch, seq)``."""
        model = BertForQuestionAnswering(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            # The same random labels serve as both start and end positions.
            start_positions=sequence_labels,
            end_positions=sequence_labels,
        )
        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

    def create_and_check_for_sequence_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check ``BertForSequenceClassification`` logits shape ``(batch, num_labels)``.

        Mutates the passed ``config`` by setting ``num_labels``.
        """
        config.num_labels = self.num_labels
        model = BertForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def create_and_check_for_token_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check ``BertForTokenClassification`` logits shape ``(batch, seq, num_labels)``.

        Mutates the passed ``config`` by setting ``num_labels``.
        """
        config.num_labels = self.num_labels
        model = BertForTokenClassification(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

    def create_and_check_for_multiple_choice(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check ``BertForMultipleChoice`` logits shape ``(batch, num_choices)``.

        Every input is tiled across ``num_choices``. Optional inputs that are
        ``None`` (because of the ``use_*`` flags) are passed through as ``None``.
        Mutates the passed ``config`` by setting ``num_choices``.
        """
        config.num_choices = self.num_choices
        model = BertForMultipleChoice(config=config)
        model.to(torch_device)
        model.eval()
        multiple_choice_inputs_ids = self._expand_for_multiple_choice(input_ids)
        multiple_choice_token_type_ids = self._expand_for_multiple_choice(token_type_ids)
        multiple_choice_input_mask = self._expand_for_multiple_choice(input_mask)
        result = model(
            multiple_choice_inputs_ids,
            attention_mask=multiple_choice_input_mask,
            token_type_ids=multiple_choice_token_type_ids,
            labels=choice_labels,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` in the form the shared ``ModelTesterMixin`` tests consume.

        ``inputs_dict`` holds ``input_ids``, ``token_type_ids`` and
        ``attention_mask``; the labels are dropped.
        """
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    """Shared model/generation mixin tests plus BERT-specific head checks.

    Every ``test_*`` method below builds inputs through ``BertModelTester`` and
    hands them to the matching ``create_and_check_*`` method of that tester.
    """

    # Without torch the BERT classes are not imported, so the mixins get empty tuples.
    if is_torch_available():
        all_model_classes = (
            BertModel,
            BertLMHeadModel,
            BertForMaskedLM,
            BertForMultipleChoice,
            BertForNextSentencePrediction,
            BertForPreTraining,
            BertForQuestionAnswering,
            BertForSequenceClassification,
            BertForTokenClassification,
        )
        all_generative_model_classes = (BertLMHeadModel,)
    else:
        all_model_classes = ()
        all_generative_model_classes = ()

    def setUp(self):
        self.model_tester = BertModelTester(self)
        self.config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)

    def _run_encoder_check(self, check):
        """Call ``check`` with the 7-value tuple from ``prepare_config_and_inputs``."""
        check(*self.model_tester.prepare_config_and_inputs())

    def _run_decoder_check(self, check):
        """Call ``check`` with the 9-value tuple from ``prepare_config_and_inputs_for_decoder``."""
        check(*self.model_tester.prepare_config_and_inputs_for_decoder())

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        self._run_encoder_check(self.model_tester.create_and_check_model)

    def test_model_as_decoder(self):
        self._run_decoder_check(self.model_tester.create_and_check_model_as_decoder)

    def test_model_as_decoder_with_default_input_mask(self):
        # Regression test: this path used to fail with PyTorch < 1.3.
        decoder_args = list(self.model_tester.prepare_config_and_inputs_for_decoder())
        # Position 3 is ``input_mask``. Clearing it means the decoder runs without
        # an explicit attention mask.
        decoder_args[3] = None
        self.model_tester.create_and_check_model_as_decoder(*decoder_args)

    def test_for_causal_lm(self):
        self._run_decoder_check(self.model_tester.create_and_check_for_causal_lm)

    def test_for_masked_lm(self):
        self._run_encoder_check(self.model_tester.create_and_check_for_masked_lm)

    def test_for_causal_lm_decoder(self):
        self._run_decoder_check(self.model_tester.create_and_check_model_for_causal_lm_as_decoder)

    def test_for_multiple_choice(self):
        self._run_encoder_check(self.model_tester.create_and_check_for_multiple_choice)

    def test_for_next_sequence_prediction(self):
        self._run_encoder_check(self.model_tester.create_and_check_for_next_sequence_prediction)

    def test_for_pretraining(self):
        self._run_encoder_check(self.model_tester.create_and_check_for_pretraining)

    def test_for_question_answering(self):
        self._run_encoder_check(self.model_tester.create_and_check_for_question_answering)

    def test_for_sequence_classification(self):
        self._run_encoder_check(self.model_tester.create_and_check_for_sequence_classification)

    def test_for_token_classification(self):
        self._run_encoder_check(self.model_tester.create_and_check_for_token_classification)

    @slow
    def test_model_from_pretrained(self):
        # Only the first entry of the archive list is loaded.
        for checkpoint in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            self.assertIsNotNone(BertModel.from_pretrained(checkpoint))