Commit 3039634d authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Add get_config() methods for mobile_bert_layers.

PiperOrigin-RevId: 344190250
parent 0ff25f6b
...@@ -76,7 +76,8 @@ class MobileBertEmbedding(tf.keras.layers.Layer): ...@@ -76,7 +76,8 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
max_sequence_length=512, max_sequence_length=512,
normalization_type='no_norm', normalization_type='no_norm',
initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02), initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
dropout_rate=0.1): dropout_rate=0.1,
**kwargs):
"""Class initialization. """Class initialization.
Arguments: Arguments:
...@@ -90,13 +91,16 @@ class MobileBertEmbedding(tf.keras.layers.Layer): ...@@ -90,13 +91,16 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
initializer: The initializer to use for the embedding weights and initializer: The initializer to use for the embedding weights and
linear projection weights. linear projection weights.
dropout_rate: Dropout rate. dropout_rate: Dropout rate.
**kwargs: keyword arguments.
""" """
super(MobileBertEmbedding, self).__init__() super(MobileBertEmbedding, self).__init__(**kwargs)
self.word_vocab_size = word_vocab_size self.word_vocab_size = word_vocab_size
self.word_embed_size = word_embed_size self.word_embed_size = word_embed_size
self.type_vocab_size = type_vocab_size self.type_vocab_size = type_vocab_size
self.output_embed_size = output_embed_size self.output_embed_size = output_embed_size
self.max_sequence_length = max_sequence_length self.max_sequence_length = max_sequence_length
self.normalization_type = normalization_type
self.initializer = tf.keras.initializers.get(initializer)
self.dropout_rate = dropout_rate self.dropout_rate = dropout_rate
self.word_embedding = keras_nlp.layers.OnDeviceEmbedding( self.word_embedding = keras_nlp.layers.OnDeviceEmbedding(
...@@ -125,6 +129,20 @@ class MobileBertEmbedding(tf.keras.layers.Layer): ...@@ -125,6 +129,20 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
self.dropout_rate, self.dropout_rate,
name='embedding_dropout') name='embedding_dropout')
def get_config(self):
  """Returns the layer's constructor arguments for serialization.

  The returned dict, combined with `from_config`, allows this layer to be
  round-tripped through Keras model saving/loading. The initializer is
  serialized to its Keras config form so it can be reconstructed.
  """
  layer_config = super(MobileBertEmbedding, self).get_config()
  layer_config.update({
      'word_vocab_size': self.word_vocab_size,
      'word_embed_size': self.word_embed_size,
      'type_vocab_size': self.type_vocab_size,
      'output_embed_size': self.output_embed_size,
      'max_sequence_length': self.max_sequence_length,
      'normalization_type': self.normalization_type,
      'initializer': tf.keras.initializers.serialize(self.initializer),
      'dropout_rate': self.dropout_rate,
  })
  return layer_config
def call(self, input_ids, token_type_ids=None): def call(self, input_ids, token_type_ids=None):
word_embedding_out = self.word_embedding(input_ids) word_embedding_out = self.word_embedding(input_ids)
word_embedding_out = tf.concat( word_embedding_out = tf.concat(
...@@ -168,7 +186,7 @@ class MobileBertTransformer(tf.keras.layers.Layer): ...@@ -168,7 +186,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
num_feedforward_networks=4, num_feedforward_networks=4,
normalization_type='no_norm', normalization_type='no_norm',
initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02), initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
name=None): **kwargs):
"""Class initialization. """Class initialization.
Arguments: Arguments:
...@@ -194,12 +212,12 @@ class MobileBertTransformer(tf.keras.layers.Layer): ...@@ -194,12 +212,12 @@ class MobileBertTransformer(tf.keras.layers.Layer):
original MobileBERT paper. 'layer_norm' is used for the teacher model. original MobileBERT paper. 'layer_norm' is used for the teacher model.
initializer: The initializer to use for the embedding weights and initializer: The initializer to use for the embedding weights and
linear projection weights. linear projection weights.
name: A string represents the layer name. **kwargs: keyword arguments.
Raises: Raises:
ValueError: A Tensor shape or parameter is invalid. ValueError: A Tensor shape or parameter is invalid.
""" """
super(MobileBertTransformer, self).__init__(name=name) super(MobileBertTransformer, self).__init__(**kwargs)
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size self.intermediate_size = intermediate_size
...@@ -211,6 +229,7 @@ class MobileBertTransformer(tf.keras.layers.Layer): ...@@ -211,6 +229,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
self.key_query_shared_bottleneck = key_query_shared_bottleneck self.key_query_shared_bottleneck = key_query_shared_bottleneck
self.num_feedforward_networks = num_feedforward_networks self.num_feedforward_networks = num_feedforward_networks
self.normalization_type = normalization_type self.normalization_type = normalization_type
self.initializer = tf.keras.initializers.get(initializer)
if intra_bottleneck_size % num_attention_heads != 0: if intra_bottleneck_size % num_attention_heads != 0:
raise ValueError( raise ValueError(
...@@ -300,6 +319,24 @@ class MobileBertTransformer(tf.keras.layers.Layer): ...@@ -300,6 +319,24 @@ class MobileBertTransformer(tf.keras.layers.Layer):
dropout_layer, dropout_layer,
layer_norm] layer_norm]
def get_config(self):
  """Returns the layer's constructor arguments for serialization.

  Together with `from_config`, this lets the transformer block survive a
  Keras save/load round trip. The initializer is converted to its
  serializable Keras config representation.
  """
  layer_config = super(MobileBertTransformer, self).get_config()
  layer_config.update({
      'hidden_size': self.hidden_size,
      'num_attention_heads': self.num_attention_heads,
      'intermediate_size': self.intermediate_size,
      'intermediate_act_fn': self.intermediate_act_fn,
      'hidden_dropout_prob': self.hidden_dropout_prob,
      'attention_probs_dropout_prob': self.attention_probs_dropout_prob,
      'intra_bottleneck_size': self.intra_bottleneck_size,
      'use_bottleneck_attention': self.use_bottleneck_attention,
      'key_query_shared_bottleneck': self.key_query_shared_bottleneck,
      'num_feedforward_networks': self.num_feedforward_networks,
      'normalization_type': self.normalization_type,
      'initializer': tf.keras.initializers.serialize(self.initializer),
  })
  return layer_config
def call(self, def call(self,
input_tensor, input_tensor,
attention_mask=None, attention_mask=None,
......
...@@ -51,6 +51,20 @@ class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -51,6 +51,20 @@ class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):
expected_shape = [1, 4, 16] expected_shape = [1, 4, 16]
self.assertListEqual(output_shape, expected_shape, msg=None) self.assertListEqual(output_shape, expected_shape, msg=None)
def test_embedding_layer_get_config(self):
  """Checks MobileBertEmbedding round-trips through get_config/from_config."""
  embedding_kwargs = dict(
      word_vocab_size=16,
      word_embed_size=32,
      type_vocab_size=4,
      output_embed_size=32,
      max_sequence_length=32,
      normalization_type='layer_norm',
      initializer=tf.keras.initializers.TruncatedNormal(stddev=0.01),
      dropout_rate=0.5)
  original_layer = mobile_bert_layers.MobileBertEmbedding(**embedding_kwargs)
  config = original_layer.get_config()
  restored_layer = mobile_bert_layers.MobileBertEmbedding.from_config(config)
  # A faithful round trip must reproduce the exact same config.
  self.assertEqual(config, restored_layer.get_config())
def test_no_norm(self): def test_no_norm(self):
layer = mobile_bert_layers.NoNorm() layer = mobile_bert_layers.NoNorm()
feature = tf.random.normal([2, 3, 4]) feature = tf.random.normal([2, 3, 4])
...@@ -92,6 +106,26 @@ class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -92,6 +106,26 @@ class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):
self.assertListEqual( self.assertListEqual(
attention_score.shape.as_list(), expected_shape, msg=None) attention_score.shape.as_list(), expected_shape, msg=None)
def test_transformer_get_config(self):
  """Checks MobileBertTransformer round-trips through get_config/from_config."""
  transformer_kwargs = dict(
      hidden_size=32,
      num_attention_heads=2,
      intermediate_size=48,
      intermediate_act_fn='gelu',
      hidden_dropout_prob=0.5,
      attention_probs_dropout_prob=0.4,
      intra_bottleneck_size=64,
      use_bottleneck_attention=True,
      key_query_shared_bottleneck=False,
      num_feedforward_networks=2,
      normalization_type='layer_norm',
      initializer=tf.keras.initializers.TruncatedNormal(stddev=0.01),
      name='block')
  original_layer = mobile_bert_layers.MobileBertTransformer(
      **transformer_kwargs)
  config = original_layer.get_config()
  restored_layer = mobile_bert_layers.MobileBertTransformer.from_config(
      config)
  # A faithful round trip must reproduce the exact same config.
  self.assertEqual(config, restored_layer.get_config())
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment