# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask


if is_torch_available():
    from transformers import (
        BertConfig,
        BertForMaskedLM,
        BertForMultipleChoice,
        BertForNextSentencePrediction,
        BertForPreTraining,
        BertForQuestionAnswering,
        BertForSequenceClassification,
        BertForTokenClassification,
        BertLMHeadModel,
        BertModel,
    )
    from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
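

# BertModelTester builds a deliberately tiny BertConfig and random input tensors,
# then runs output-shape checks for the base model and each task-specific Bert head.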
class BertModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = BertConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
            return_dict=True,
        )

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def prepare_config_and_inputs_for_decoder(self):
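        # Same config and inputs as above, plus encoder hidden states and an encoder
        # attention mask so decoder/cross-attention code paths can be exercised.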
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = self.prepare_config_and_inputs()

        config.is_decoder = True
        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = BertModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        result = model(input_ids, token_type_ids=token_type_ids)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.add_cross_attention = True
        model = BertModel(config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            encoder_hidden_states=encoder_hidden_states,
        )
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        model = BertLMHeadModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_for_masked_lm(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = BertForMaskedLM(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_model_for_causal_lm_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.add_cross_attention = True
        model = BertLMHeadModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=token_labels,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=token_labels,
            encoder_hidden_states=encoder_hidden_states,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_for_next_sequence_prediction(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = BertForNextSentencePrediction(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            next_sentence_label=sequence_labels,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, 2))

    def create_and_check_for_pretraining(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = BertForPreTraining(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=token_labels,
            next_sentence_label=sequence_labels,
        )
        self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
        self.parent.assertEqual(result.seq_relationship_logits.shape, (self.batch_size, 2))

    def create_and_check_for_question_answering(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = BertForQuestionAnswering(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            start_positions=sequence_labels,
            end_positions=sequence_labels,
        )
        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

    def create_and_check_for_sequence_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_labels = self.num_labels
        model = BertForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def create_and_check_for_token_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_labels = self.num_labels
        model = BertForTokenClassification(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

    def create_and_check_for_multiple_choice(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_choices = self.num_choices
        model = BertForMultipleChoice(config=config)
        model.to(torch_device)
        model.eval()
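        # Repeat each example across the choice dimension: (batch, seq_len) -> (batch, num_choices, seq_len).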
        multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        result = model(
            multiple_choice_inputs_ids,
            attention_mask=multiple_choice_input_mask,
            token_type_ids=multiple_choice_token_type_ids,
            labels=choice_labels,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class BertModelTest(ModelTesterMixin, unittest.TestCase):
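    # Runs the shared ModelTesterMixin checks over every class in all_model_classes,
    # plus the per-task tests defined below; the @slow download test is skipped
    # unless slow tests are enabled.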

    all_model_classes = (
        (
            BertModel,
            BertForMaskedLM,
            BertForMultipleChoice,
            BertForNextSentencePrediction,
            BertForPreTraining,
            BertForQuestionAnswering,
            BertForSequenceClassification,
            BertForTokenClassification,
        )
        if is_torch_available()
        else ()
    )

    def setUp(self):
        self.model_tester = BertModelTester(self)
        self.config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_as_decoder(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)

    def test_model_as_decoder_with_default_input_mask(self):
        # This regression test was failing with PyTorch < 1.3
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        ) = self.model_tester.prepare_config_and_inputs_for_decoder()

        input_mask = None

        self.model_tester.create_and_check_model_as_decoder(
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def test_for_causal_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)

    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)

    def test_for_causal_lm_decoder(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_model_for_causal_lm_as_decoder(*config_and_inputs)

    def test_for_multiple_choice(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)

    def test_for_next_sequence_prediction(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_next_sequence_prediction(*config_and_inputs)

    def test_for_pretraining(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_pretraining(*config_and_inputs)

    def test_for_question_answering(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_question_answering(*config_and_inputs)

    def test_for_sequence_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)

    def test_for_token_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_token_classification(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = BertModel.from_pretrained(model_name)
            self.assertIsNotNone(model)