# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor


if is_torch_available():
    from transformers import (
        BertConfig,
        BertModel,
        BertLMHeadModel,
        BertForMaskedLM,
        BertForNextSentencePrediction,
        BertForPreTraining,
        BertForQuestionAnswering,
        BertForSequenceClassification,
        BertForTokenClassification,
        BertForMultipleChoice,
    )
    from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST


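# Builds a small random BertConfig plus matching input tensors; the create_and_check_*
# methods below instantiate each BERT head on top of them and verify output shapes.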
class BertModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = BertConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
            return_dict=True,
        )

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def prepare_config_and_inputs_for_decoder(self):
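        # Reuse the standard inputs and add encoder hidden states / attention mask so the
        # model can be exercised as a decoder with cross-attention.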
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = self.prepare_config_and_inputs()

        config.is_decoder = True
        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def create_and_check_bert_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = BertModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        result = model(input_ids, token_type_ids=token_type_ids)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_bert_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
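        # Cross-attention layers are only created when the config explicitly enables them.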
        config.add_cross_attention = True
        model = BertModel(config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            encoder_hidden_states=encoder_hidden_states,
        )
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_bert_for_causal_lm(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        model = BertLMHeadModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_bert_for_masked_lm(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = BertForMaskedLM(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_bert_model_for_causal_lm_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.add_cross_attention = True
        model = BertLMHeadModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=token_labels,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=token_labels,
            encoder_hidden_states=encoder_hidden_states,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_bert_for_next_sequence_prediction(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = BertForNextSentencePrediction(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, next_sentence_label=sequence_labels,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, 2))

    def create_and_check_bert_for_pretraining(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = BertForPreTraining(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=token_labels,
            next_sentence_label=sequence_labels,
        )
        self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
        self.parent.assertEqual(result.seq_relationship_logits.shape, (self.batch_size, 2))

    def create_and_check_bert_for_question_answering(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = BertForQuestionAnswering(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            start_positions=sequence_labels,
            end_positions=sequence_labels,
        )
        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

    def create_and_check_bert_for_sequence_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_labels = self.num_labels
        model = BertForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def create_and_check_bert_for_token_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_labels = self.num_labels
        model = BertForTokenClassification(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

    def create_and_check_bert_for_multiple_choice(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_choices = self.num_choices
        model = BertForMultipleChoice(config=config)
        model.to(torch_device)
        model.eval()
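        # Multiple-choice heads expect inputs of shape (batch_size, num_choices, seq_length),
        # so each tensor is repeated along a new choices dimension.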
        multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        result = model(
            multiple_choice_inputs_ids,
            attention_mask=multiple_choice_input_mask,
            token_type_ids=multiple_choice_token_type_ids,
            labels=choice_labels,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

    def prepare_config_and_inputs_for_common(self):
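        # Repackage the prepared inputs into the dict format consumed by the shared
        # ModelTesterMixin tests.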
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
        return config, inputs_dict


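# Runs the generic ModelTesterMixin checks over every BERT model class, plus the
# BERT-specific tests defined on BertModelTester.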
@require_torch
class BertModelTest(ModelTesterMixin, unittest.TestCase):

    all_model_classes = (
        (
            BertModel,
            BertForMaskedLM,
            BertForMultipleChoice,
            BertForNextSentencePrediction,
            BertForPreTraining,
            BertForQuestionAnswering,
            BertForSequenceClassification,
            BertForTokenClassification,
        )
        if is_torch_available()
        else ()
    )
    test_chunking = True

    def setUp(self):
        self.model_tester = BertModelTester(self)
        self.config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_bert_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_bert_model(*config_and_inputs)

    def test_bert_model_as_decoder(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_bert_model_as_decoder(*config_and_inputs)

    def test_bert_model_as_decoder_with_default_input_mask(self):
        # This regression test was failing with PyTorch < 1.3
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        ) = self.model_tester.prepare_config_and_inputs_for_decoder()

        input_mask = None

        self.model_tester.create_and_check_bert_model_as_decoder(
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def test_for_causal_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_bert_for_causal_lm(*config_and_inputs)

    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)

    def test_for_causal_lm_decoder(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_bert_model_for_causal_lm_as_decoder(*config_and_inputs)

    def test_for_multiple_choice(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)

    def test_for_next_sequence_prediction(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs)

    def test_for_pretraining(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_bert_for_pretraining(*config_and_inputs)

    def test_for_question_answering(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_bert_for_question_answering(*config_and_inputs)

    def test_for_sequence_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_bert_for_sequence_classification(*config_and_inputs)

    def test_for_token_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs)

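    # Marked @slow because it downloads a pretrained checkpoint; it only runs when the
    # RUN_SLOW environment variable is set.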
    @slow
    def test_model_from_pretrained(self):
        for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = BertModel.from_pretrained(model_name)
            self.assertIsNotNone(model)