# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
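"""Unit tests for BertConfig serialization and BertModel output shapes."""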
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
import json
import random

import torch

from pytorch_pretrained_bert import BertConfig, BertModel


class BertModelTest(unittest.TestCase):
    class BertModelTester(object):
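        """Builds a small BertModel on random inputs and checks the shapes of its outputs."""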

        def __init__(self,
                     parent,
                     batch_size=13,
                     seq_length=7,
                     is_training=True,
                     use_input_mask=True,
                     use_token_type_ids=True,
                     vocab_size=99,
                     hidden_size=32,
                     num_hidden_layers=5,
                     num_attention_heads=4,
                     intermediate_size=37,
                     hidden_act="gelu",
                     hidden_dropout_prob=0.1,
                     attention_probs_dropout_prob=0.1,
                     max_position_embeddings=512,
                     type_vocab_size=16,
                     initializer_range=0.02,
                     scope=None):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.is_training = is_training
            self.use_input_mask = use_input_mask
            self.use_token_type_ids = use_token_type_ids
            self.vocab_size = vocab_size
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.intermediate_size = intermediate_size
            self.hidden_act = hidden_act
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
            self.scope = scope

        def create_model(self):
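            """Creates random inputs, builds a small BertConfig and runs a BertModel forward pass."""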
            input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

            input_mask = None
            if self.use_input_mask:
                input_mask = BertModelTest.ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

            token_type_ids = None
            if self.use_token_type_ids:
                token_type_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

            config = BertConfig(
                vocab_size_or_config_json_file=self.vocab_size,
                hidden_size=self.hidden_size,
                num_hidden_layers=self.num_hidden_layers,
                num_attention_heads=self.num_attention_heads,
                intermediate_size=self.intermediate_size,
                hidden_act=self.hidden_act,
                hidden_dropout_prob=self.hidden_dropout_prob,
                attention_probs_dropout_prob=self.attention_probs_dropout_prob,
                max_position_embeddings=self.max_position_embeddings,
                type_vocab_size=self.type_vocab_size,
                initializer_range=self.initializer_range)

            model = BertModel(config=config)

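            # The forward pass returns the hidden states of every layer and the pooled output of the first token.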
            all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)

            outputs = {
                "sequence_output": all_encoder_layers[-1],
                "pooled_output": pooled_output,
                "all_encoder_layers": all_encoder_layers,
            }
            return outputs

        def check_output(self, result):
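            """Asserts that the sequence and pooled outputs have the expected shapes."""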
            self.parent.assertListEqual(
                list(result["sequence_output"].size()),
                [self.batch_size, self.seq_length, self.hidden_size])

            self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])

    def test_default(self):
        self.run_tester(BertModelTest.BertModelTester(self))

    def test_config_to_json_string(self):
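        # Serializing a BertConfig to JSON and parsing it back should preserve its fields.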
        config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37)
        obj = json.loads(config.to_json_string())
        self.assertEqual(obj["vocab_size"], 99)
        self.assertEqual(obj["hidden_size"], 37)

    def run_tester(self, tester):
        output_result = tester.create_model()
        tester.check_output(output_result)

    @classmethod
    def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
        """Creates a random int32 tensor of the shape within the vocab size."""
        if rng is None:
            rng = random.Random()

        total_dims = 1
        for dim in shape:
            total_dims *= dim

        values = []
        for _ in range(total_dims):
            values.append(rng.randint(0, vocab_size - 1))

        return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()


if __name__ == "__main__":
    unittest.main()