# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

from transformers import BertConfig, is_flax_available
from transformers.testing_utils import require_flax, slow

from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask


if is_flax_available():
    from transformers.models.bert.modeling_flax_bert import (
        FlaxBertForMaskedLM,
        FlaxBertForMultipleChoice,
        FlaxBertForNextSentencePrediction,
        FlaxBertForPreTraining,
        FlaxBertForQuestionAnswering,
        FlaxBertForSequenceClassification,
        FlaxBertForTokenClassification,
        FlaxBertModel,
    )


class FlaxBertModelTester(unittest.TestCase):
    """Build small random BERT configs and inputs for the Flax BERT model tests.

    ``FlaxBertModelTest.setUp`` creates one instance with the running test case as
    ``parent``. The keyword defaults describe a very small model (hidden_size=32,
    two layers, four heads), presumably to keep the shared tests fast.

    NOTE(review): this helper subclasses ``unittest.TestCase`` but overrides
    ``__init__`` without calling ``super().__init__()``, so an instance is not a
    usable test case by itself; only its attributes and ``prepare_*`` helpers are
    meant to be used.
    """

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_attention_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_choices=4,
    ):
        # The owning test case; stored for callers, not used by the helpers below.
        self.parent = parent

        # Input-shape and input-selection knobs.
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_attention_mask = use_attention_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels

        # Model hyper-parameters forwarded into BertConfig.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range

        # Head-specific settings (not read by the helpers in this class).
        self.type_sequence_label_size = type_sequence_label_size
        self.num_choices = num_choices

    def prepare_config_and_inputs(self):
        """Create an encoder ``BertConfig`` plus random inputs.

        Returns:
            ``(config, input_ids, token_type_ids, attention_mask)``. ``token_type_ids``
            is ``None`` unless ``use_token_type_ids`` is set, and ``attention_mask``
            is ``None`` unless ``use_attention_mask`` is set.
        """
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        attention_mask = None
        if self.use_attention_mask:
            # Shape [batch_size, seq_length]; contents come from the shared helper.
            attention_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            # ids drawn from [0, type_vocab_size) per the helper's vocab argument (assumed).
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        config = BertConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
        )

        return config, input_ids, token_type_ids, attention_mask

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` in the form the shared tester mixin consumes.

        ``inputs_dict`` has the keys ``"input_ids"``, ``"token_type_ids"`` and
        ``"attention_mask"``. The last two may be ``None`` (see
        ``prepare_config_and_inputs``).
        """
        config, input_ids, token_type_ids, attention_mask = self.prepare_config_and_inputs()
        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
        return config, inputs_dict

    def prepare_config_and_inputs_for_decoder(self):
        """Return a decoder config plus inputs with random encoder states for cross-attention.

        Returns:
            ``(config, input_ids, attention_mask, encoder_hidden_states,
            encoder_attention_mask)``. ``config.is_decoder`` is set to ``True``.
            ``token_type_ids`` is intentionally not part of the returned tuple.
        """
        config, input_ids, token_type_ids, attention_mask = self.prepare_config_and_inputs()

        config.is_decoder = True
        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        # vocab_size=2 presumably yields a 0/1 mask over the encoder positions.
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (
            config,
            input_ids,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
        )


@require_flax
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
    """Run the shared ``FlaxModelTesterMixin`` checks against every Flax BERT model class."""

    # Flag read by FlaxModelTesterMixin (presumably enables its head-masking tests).
    test_head_masking = True

    # Each Flax BERT head is listed exactly once; a duplicate entry would make
    # per-class checks process that head twice. The tuple is empty when Flax is
    # not installed, so the module still imports (the model names are only bound
    # under `is_flax_available()`).
    all_model_classes = (
        (
            FlaxBertModel,
            FlaxBertForPreTraining,
            FlaxBertForMaskedLM,
            FlaxBertForMultipleChoice,
            FlaxBertForQuestionAnswering,
            FlaxBertForNextSentencePrediction,
            FlaxBertForSequenceClassification,
            FlaxBertForTokenClassification,
        )
        if is_flax_available()
        else ()
    )

    def setUp(self):
        # Fresh tester per test; it keeps a reference to this test case as `parent`.
        self.model_tester = FlaxBertModelTester(self)

    @slow
    def test_model_from_pretrained(self):
        """Load a pretrained checkpoint and verify a forward pass returns something."""
        # Only check this for base model, not necessary for all model classes.
        # This will also help speed-up tests.
        model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
        outputs = model(np.ones((1, 1)))
        self.assertIsNotNone(outputs)