# modeling_test.py
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import random
import shutil
import tempfile
import unittest

import pytest
import torch

from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
                                     BertForNextSentencePrediction, BertForPreTraining,
                                     BertForQuestionAnswering, BertForSequenceClassification,
                                     BertForTokenClassification, BertForMultipleChoice)
from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP


class BertModelTest(unittest.TestCase):
    """Shape/smoke tests for the BERT model heads and ``BertConfig`` serialization."""

    class BertModelTester(object):
        """Builds a tiny random BERT config plus inputs and checks each model head on them.

        ``parent`` is the running ``unittest.TestCase``; every assertion goes through it.
        The defaults are intentionally small so all checks run quickly on CPU.
        """

        def __init__(self,
                     parent,
                     batch_size=13,
                     seq_length=7,
                     is_training=True,
                     use_input_mask=True,
                     use_token_type_ids=True,
                     use_labels=True,
                     vocab_size=99,
                     hidden_size=32,
                     num_hidden_layers=5,
                     num_attention_heads=4,
                     intermediate_size=37,
                     hidden_act="gelu",
                     hidden_dropout_prob=0.1,
                     attention_probs_dropout_prob=0.1,
                     max_position_embeddings=512,
                     type_vocab_size=16,
                     type_sequence_label_size=2,
                     initializer_range=0.02,
                     num_labels=3,
                     num_choices=4,
                     scope=None):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.is_training = is_training  # stored only; not read by any check in this class
            self.use_input_mask = use_input_mask
            self.use_token_type_ids = use_token_type_ids
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.intermediate_size = intermediate_size
            self.hidden_act = hidden_act
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.type_sequence_label_size = type_sequence_label_size
            self.initializer_range = initializer_range
            self.num_labels = num_labels
            self.num_choices = num_choices
            self.scope = scope  # stored only; not read by any check in this class

        def prepare_config_and_inputs(self):
            """Build a random config and matching inputs.

            Returns:
                ``(config, input_ids, token_type_ids, input_mask, sequence_labels,
                token_labels, choice_labels)``. Each optional tensor is ``None`` when
                its ``use_*`` flag is off.
            """
            input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

            input_mask = None
            if self.use_input_mask:
                # Random 0/1 values, one per token position.
                input_mask = BertModelTest.ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

            token_type_ids = None
            if self.use_token_type_ids:
                token_type_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

            sequence_labels = None
            token_labels = None
            choice_labels = None
            if self.use_labels:
                sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
                token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels)
                choice_labels = BertModelTest.ids_tensor([self.batch_size], self.num_choices)

            config = BertConfig(
                vocab_size_or_config_json_file=self.vocab_size,
                hidden_size=self.hidden_size,
                num_hidden_layers=self.num_hidden_layers,
                num_attention_heads=self.num_attention_heads,
                intermediate_size=self.intermediate_size,
                hidden_act=self.hidden_act,
                hidden_dropout_prob=self.hidden_dropout_prob,
                attention_probs_dropout_prob=self.attention_probs_dropout_prob,
                max_position_embeddings=self.max_position_embeddings,
                type_vocab_size=self.type_vocab_size,
                initializer_range=self.initializer_range)

            return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

        def check_loss_output(self, result):
            """Assert that ``result["loss"]`` is a 0-dimensional (scalar) tensor."""
            self.parent.assertListEqual(
                list(result["loss"].size()),
                [])

        # ------------------------------------------------------------------
        # Per-head forward passes (create_*) and their shape checks (check_*).
        # Each create_* builds an eval-mode model. Where the head supports it, the
        # model is called once with labels (yielding the loss) and once without
        # (yielding the raw scores).
        # ------------------------------------------------------------------

        def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            """Run the bare ``BertModel`` and collect its encoder and pooled outputs."""
            model = BertModel(config=config)
            model.eval()
            all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
            outputs = {
                "sequence_output": all_encoder_layers[-1],
                "pooled_output": pooled_output,
                "all_encoder_layers": all_encoder_layers,
            }
            return outputs

        def check_bert_model_output(self, result):
            """Every encoder layer and the final sequence output are (batch, seq, hidden); pooled is (batch, hidden)."""
            self.parent.assertListEqual(
                [size for layer in result["all_encoder_layers"] for size in layer.size()],
                [self.batch_size, self.seq_length, self.hidden_size] * self.num_hidden_layers)
            self.parent.assertListEqual(
                list(result["sequence_output"].size()),
                [self.batch_size, self.seq_length, self.hidden_size])
            self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])

        def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            """Run ``BertForMaskedLM`` with and without token-level labels."""
            model = BertForMaskedLM(config=config)
            model.eval()
            loss = model(input_ids, token_type_ids, input_mask, token_labels)
            prediction_scores = model(input_ids, token_type_ids, input_mask)
            outputs = {
                "loss": loss,
                "prediction_scores": prediction_scores,
            }
            return outputs

        def check_bert_for_masked_lm_output(self, result):
            """Prediction scores are (batch, seq, vocab)."""
            self.parent.assertListEqual(
                list(result["prediction_scores"].size()),
                [self.batch_size, self.seq_length, self.vocab_size])

        def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            """Run ``BertForNextSentencePrediction`` with and without sequence-level labels."""
            model = BertForNextSentencePrediction(config=config)
            model.eval()
            loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
            seq_relationship_score = model(input_ids, token_type_ids, input_mask)
            outputs = {
                "loss": loss,
                "seq_relationship_score": seq_relationship_score,
            }
            return outputs

        def check_bert_for_next_sequence_prediction_output(self, result):
            """Sentence-relationship scores are (batch, 2)."""
            self.parent.assertListEqual(
                list(result["seq_relationship_score"].size()),
                [self.batch_size, 2])

        def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            """Run ``BertForPreTraining`` with token and sequence labels, then without labels."""
            model = BertForPreTraining(config=config)
            model.eval()
            loss = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
            prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, input_mask)
            outputs = {
                "loss": loss,
                "prediction_scores": prediction_scores,
                "seq_relationship_score": seq_relationship_score,
            }
            return outputs

        def check_bert_for_pretraining_output(self, result):
            """LM scores are (batch, seq, vocab); sentence-relationship scores are (batch, 2)."""
            self.parent.assertListEqual(
                list(result["prediction_scores"].size()),
                [self.batch_size, self.seq_length, self.vocab_size])
            self.parent.assertListEqual(
                list(result["seq_relationship_score"].size()),
                [self.batch_size, 2])

        def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            """Run ``BertForQuestionAnswering`` with span labels, then without labels."""
            model = BertForQuestionAnswering(config=config)
            model.eval()
            # sequence_labels serve as both start and end positions. With the default sizes
            # their values (< type_sequence_label_size) are valid token indices.
            loss = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
            start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
            outputs = {
                "loss": loss,
                "start_logits": start_logits,
                "end_logits": end_logits,
            }
            return outputs

        def check_bert_for_question_answering_output(self, result):
            """Start and end logits are each (batch, seq)."""
            self.parent.assertListEqual(
                list(result["start_logits"].size()),
                [self.batch_size, self.seq_length])
            self.parent.assertListEqual(
                list(result["end_logits"].size()),
                [self.batch_size, self.seq_length])

        def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            """Run ``BertForSequenceClassification`` with and without sequence-level labels."""
            model = BertForSequenceClassification(config=config, num_labels=self.num_labels)
            model.eval()
            loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
            logits = model(input_ids, token_type_ids, input_mask)
            outputs = {
                "loss": loss,
                "logits": logits,
            }
            return outputs

        def check_bert_for_sequence_classification_output(self, result):
            """Logits are (batch, num_labels)."""
            self.parent.assertListEqual(
                list(result["logits"].size()),
                [self.batch_size, self.num_labels])

        def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            """Run ``BertForTokenClassification`` with and without token-level labels."""
            model = BertForTokenClassification(config=config, num_labels=self.num_labels)
            model.eval()
            loss = model(input_ids, token_type_ids, input_mask, token_labels)
            logits = model(input_ids, token_type_ids, input_mask)
            outputs = {
                "loss": loss,
                "logits": logits,
            }
            return outputs

        def check_bert_for_token_classification_output(self, result):
            """Logits are (batch, seq, num_labels)."""
            self.parent.assertListEqual(
                list(result["logits"].size()),
                [self.batch_size, self.seq_length, self.num_labels])

        def create_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            """Run ``BertForMultipleChoice`` on inputs replicated once per choice.

            NOTE(review): ``token_type_ids`` and ``input_mask`` must be tensors here.
            With ``use_token_type_ids``/``use_input_mask`` disabled they are ``None``,
            and this method fails.
            """
            model = BertForMultipleChoice(config=config, num_choices=self.num_choices)
            model.eval()

            def tile_per_choice(tensor):
                # (batch, seq) -> (batch, num_choices, seq); contiguous() materializes the
                # expanded view as a real tensor.
                return tensor.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()

            multiple_choice_inputs_ids = tile_per_choice(input_ids)
            multiple_choice_token_type_ids = tile_per_choice(token_type_ids)
            multiple_choice_input_mask = tile_per_choice(input_mask)
            loss = model(multiple_choice_inputs_ids,
                         multiple_choice_token_type_ids,
                         multiple_choice_input_mask,
                         choice_labels)
            logits = model(multiple_choice_inputs_ids,
                           multiple_choice_token_type_ids,
                           multiple_choice_input_mask)
            outputs = {
                "loss": loss,
                "logits": logits,
            }
            return outputs

        def check_bert_for_multiple_choice(self, result):
            """Logits are (batch, num_choices)."""
            self.parent.assertListEqual(
                list(result["logits"].size()),
                [self.batch_size, self.num_choices])

        # ------------------------------------------------------------------
        # Checks that iterate over all model heads.
        # ------------------------------------------------------------------

        def _iter_head_test_models(self, config, **model_kwargs):
            """Yield an eval-mode instance of each BERT model class, built with ``model_kwargs``.

            The two classification heads also receive ``num_labels=self.num_labels``.
            Models are built lazily, one per iteration step.
            """
            for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
                                BertForPreTraining, BertForQuestionAnswering,
                                BertForSequenceClassification, BertForTokenClassification):
                if model_class in (BertForSequenceClassification, BertForTokenClassification):
                    model = model_class(config=config, num_labels=self.num_labels, **model_kwargs)
                else:
                    model = model_class(config=config, **model_kwargs)
                model.eval()
                yield model

        @staticmethod
        def _backward_on_output_sum(model, output):
            """Reduce a model's forward output to one scalar and backpropagate through it."""
            if isinstance(model, BertModel):
                # BertModel's first output element holds the per-layer encoder outputs
                # (compare create_bert_model); only those are summed.
                output = sum(t.sum() for t in output[0])
            elif isinstance(output, (list, tuple)):
                output = sum(t.sum() for t in output)
            output = output.sum()
            output.backward()

        def create_and_check_bert_for_attentions(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            """With ``output_attentions=True``, each head returns one attention map per layer."""
            for model in self._iter_head_test_models(config, output_attentions=True):
                output = model(input_ids, token_type_ids, input_mask)
                # The attention maps are expected as the first element of the returned tuple.
                attentions = output[0]
                self.parent.assertEqual(len(attentions), self.num_hidden_layers)
                self.parent.assertListEqual(
                    list(attentions[0].size()),
                    [self.batch_size, self.num_attention_heads, self.seq_length, self.seq_length])

        def create_and_check_bert_for_headmasking(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            """Masking every head but the first and last zeroes the middle heads' multihead outputs."""
            head_dim = self.hidden_size // self.num_attention_heads
            for model in self._iter_head_test_models(config, keep_multihead_output=True):
                head_mask = torch.ones(self.num_attention_heads).to(input_ids.device)
                # Mask all but the first and last heads. In this API's head_mask convention,
                # entries left at 1 are the masked heads: the assertions below expect the
                # middle heads' outputs to be all zero and the 0-valued first/last heads
                # to stay fully active.
                head_mask[0] = 0.0
                head_mask[-1] = 0.0
                output = model(input_ids, token_type_ids, input_mask, head_mask=head_mask)

                self._backward_on_output_sum(model, output)
                multihead_outputs = (model if isinstance(model, BertModel) else model.bert).get_multihead_outputs()

                self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
                self.parent.assertListEqual(
                    list(multihead_outputs[0].size()),
                    [self.batch_size, self.num_attention_heads, self.seq_length, head_dim])
                self.parent.assertEqual(
                    len(multihead_outputs[0][:, 1:(self.num_attention_heads - 1), :, :].nonzero()),
                    0)
                self.parent.assertEqual(
                    len(multihead_outputs[0][:, 0, :, :].nonzero()),
                    self.batch_size * self.seq_length * head_dim)
                self.parent.assertEqual(
                    len(multihead_outputs[0][:, self.num_attention_heads - 1, :, :].nonzero()),
                    self.batch_size * self.seq_length * head_dim)

        def create_and_check_bert_for_head_pruning(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            """Pruning heads shrinks the head dimension of the affected layers' multihead outputs."""
            head_dim = self.hidden_size // self.num_attention_heads
            for model in self._iter_head_test_models(config, keep_multihead_output=True):
                bert_model = model if isinstance(model, BertModel) else model.bert
                # Layer 0: keep only head 0. Last layer: drop head 0.
                heads_to_prune = {0: list(range(1, self.num_attention_heads)),
                                  -1: [0]}
                bert_model.prune_heads(heads_to_prune)
                output = model(input_ids, token_type_ids, input_mask)

                self._backward_on_output_sum(model, output)
                multihead_outputs = bert_model.get_multihead_outputs()

                self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
                self.parent.assertListEqual(
                    list(multihead_outputs[0].size()),
                    [self.batch_size, 1, self.seq_length, head_dim])
                self.parent.assertListEqual(
                    list(multihead_outputs[1].size()),
                    [self.batch_size, self.num_attention_heads, self.seq_length, head_dim])
                self.parent.assertListEqual(
                    list(multihead_outputs[-1].size()),
                    [self.batch_size, self.num_attention_heads - 1, self.seq_length, head_dim])

    def test_default(self):
        self.run_tester(BertModelTest.BertModelTester(self))

    def test_config_to_json_string(self):
        config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37)
        obj = json.loads(config.to_json_string())
        self.assertEqual(obj["vocab_size"], 99)
        self.assertEqual(obj["hidden_size"], 37)

    def test_config_to_json_file(self):
        """A config written with to_json_file and read back with from_json_file is unchanged."""
        config_first = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37)
        # Use a private temporary directory instead of a fixed /tmp path. Concurrent runs
        # then cannot collide, the test works where /tmp does not exist, and the file is
        # removed even if (de)serialization raises.
        tmp_dir = tempfile.mkdtemp()
        try:
            json_file_path = os.path.join(tmp_dir, "config.json")
            config_first.to_json_file(json_file_path)
            config_second = BertConfig.from_json_file(json_file_path)
        finally:
            shutil.rmtree(tmp_dir)
        self.assertEqual(config_second.to_dict(), config_first.to_dict())

    @pytest.mark.slow
    def test_model_from_pretrained(self):
        """Download/load the first pretrained checkpoint into a throwaway cache directory."""
        cache_dir = tempfile.mkdtemp(prefix="pytorch_pretrained_bert_test_")
        try:
            for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
                model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
                self.assertIsNotNone(model)
        finally:
            # Always clean up, even when loading or the assertion fails.
            shutil.rmtree(cache_dir, ignore_errors=True)

    def run_tester(self, tester):
        """Run every create/check pair of ``tester`` on one shared set of random inputs."""
        config_and_inputs = tester.prepare_config_and_inputs()
        output_result = tester.create_bert_model(*config_and_inputs)
        tester.check_bert_model_output(output_result)

        output_result = tester.create_bert_for_masked_lm(*config_and_inputs)
        tester.check_bert_for_masked_lm_output(output_result)
        tester.check_loss_output(output_result)

        output_result = tester.create_bert_for_next_sequence_prediction(*config_and_inputs)
        tester.check_bert_for_next_sequence_prediction_output(output_result)
        tester.check_loss_output(output_result)

        output_result = tester.create_bert_for_pretraining(*config_and_inputs)
        tester.check_bert_for_pretraining_output(output_result)
        tester.check_loss_output(output_result)

        output_result = tester.create_bert_for_question_answering(*config_and_inputs)
        tester.check_bert_for_question_answering_output(output_result)
        tester.check_loss_output(output_result)

        output_result = tester.create_bert_for_sequence_classification(*config_and_inputs)
        tester.check_bert_for_sequence_classification_output(output_result)
        tester.check_loss_output(output_result)

        output_result = tester.create_bert_for_token_classification(*config_and_inputs)
        tester.check_bert_for_token_classification_output(output_result)
        tester.check_loss_output(output_result)

        output_result = tester.create_bert_for_multiple_choice(*config_and_inputs)
        tester.check_bert_for_multiple_choice(output_result)
        tester.check_loss_output(output_result)

        tester.create_and_check_bert_for_attentions(*config_and_inputs)
        tester.create_and_check_bert_for_headmasking(*config_and_inputs)
        tester.create_and_check_bert_for_head_pruning(*config_and_inputs)

    @classmethod
    def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
        """Create a random int64 (``torch.long``) tensor of ``shape`` with values in ``[0, vocab_size)``.

        Args:
            shape: sequence of ints giving the tensor's dimensions.
            vocab_size: exclusive upper bound of the sampled ids.
            rng: optional ``random.Random``; a fresh, unseeded one is used when omitted.
            name: unused; kept for signature compatibility.
        """
        if rng is None:
            rng = random.Random()

        total_dims = 1
        for dim in shape:
            total_dims *= dim

        # randint is inclusive on both ends, hence vocab_size - 1.
        values = [rng.randint(0, vocab_size - 1) for _ in range(total_dims)]

        return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
if __name__ == "__main__":
    # Allow running this module directly as a script.
    unittest.main()