ModelZoo / ResNet50_tensorflow / Commits

Commit 3039634d, authored Nov 24, 2020 by Chen Chen; committed by A. Unique TensorFlower, Nov 24, 2020

Add get_config() methods for mobile_bert_layers.

PiperOrigin-RevId: 344190250
Parent: 0ff25f6b
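For readers unfamiliar with the pattern: get_config() is the standard Keras hook for layer serialization. It returns a plain dict of the layer's constructor arguments, and the from_config() classmethod inherited from tf.keras.layers.Layer rebuilds an equivalent layer from that dict; forwarding **kwargs to the base __init__ lets Keras route standard arguments such as name and dtype through the same path. Below is a minimal sketch of the pattern this commit applies, using a hypothetical ScaledDense layer rather than code from this repository:

```python
import tensorflow as tf


class ScaledDense(tf.keras.layers.Layer):
  """Hypothetical layer illustrating the get_config() pattern."""

  def __init__(self, units=8, scale=2.0, **kwargs):
    # Forwarding **kwargs lets Keras route standard arguments
    # (`name`, `dtype`, `trainable`, ...) through the base class.
    super(ScaledDense, self).__init__(**kwargs)
    self.units = units
    self.scale = scale
    self.dense = tf.keras.layers.Dense(units)

  def call(self, inputs):
    return self.scale * self.dense(inputs)

  def get_config(self):
    # Merge this layer's constructor arguments into the base config,
    # mirroring the dict(list(...) + list(...)) idiom in the diff below.
    config = {'units': self.units, 'scale': self.scale}
    base_config = super(ScaledDense, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


layer = ScaledDense(units=4, scale=0.5, name='scaled')
clone = ScaledDense.from_config(layer.get_config())
assert layer.get_config() == clone.get_config()
```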
Showing 2 changed files with 76 additions and 5 deletions (+76 −5):

  official/nlp/modeling/layers/mobile_bert_layers.py       +42 −5
  official/nlp/modeling/layers/mobile_bert_layers_test.py  +34 −0
official/nlp/modeling/layers/mobile_bert_layers.py

@@ -76,7 +76,8 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
                max_sequence_length=512,
                normalization_type='no_norm',
                initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
-               dropout_rate=0.1):
+               dropout_rate=0.1,
+               **kwargs):
     """Class initialization.

     Arguments:
@@ -90,13 +91,16 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
       initializer: The initializer to use for the embedding weights and
         linear projection weights.
       dropout_rate: Dropout rate.
+      **kwargs: keyword arguments.
     """
-    super(MobileBertEmbedding, self).__init__()
+    super(MobileBertEmbedding, self).__init__(**kwargs)
     self.word_vocab_size = word_vocab_size
     self.word_embed_size = word_embed_size
     self.type_vocab_size = type_vocab_size
     self.output_embed_size = output_embed_size
     self.max_sequence_length = max_sequence_length
+    self.normalization_type = normalization_type
+    self.initializer = tf.keras.initializers.get(initializer)
     self.dropout_rate = dropout_rate

     self.word_embedding = keras_nlp.layers.OnDeviceEmbedding(
@@ -125,6 +129,20 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
         self.dropout_rate,
         name='embedding_dropout')

+  def get_config(self):
+    config = {
+        'word_vocab_size': self.word_vocab_size,
+        'word_embed_size': self.word_embed_size,
+        'type_vocab_size': self.type_vocab_size,
+        'output_embed_size': self.output_embed_size,
+        'max_sequence_length': self.max_sequence_length,
+        'normalization_type': self.normalization_type,
+        'initializer': tf.keras.initializers.serialize(self.initializer),
+        'dropout_rate': self.dropout_rate
+    }
+    base_config = super(MobileBertEmbedding, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
+
   def call(self, input_ids, token_type_ids=None):
     word_embedding_out = self.word_embedding(input_ids)
     word_embedding_out = tf.concat(
@@ -168,7 +186,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
                num_feedforward_networks=4,
                normalization_type='no_norm',
                initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
-               name=None):
+               **kwargs):
     """Class initialization.

     Arguments:
@@ -194,12 +212,12 @@ class MobileBertTransformer(tf.keras.layers.Layer):
         original MobileBERT paper. 'layer_norm' is used for the teacher model.
       initializer: The initializer to use for the embedding weights and
         linear projection weights.
-      name: A string represents the layer name.
+      **kwargs: keyword arguments.

     Raises:
       ValueError: A Tensor shape or parameter is invalid.
     """
-    super(MobileBertTransformer, self).__init__(name=name)
+    super(MobileBertTransformer, self).__init__(**kwargs)
     self.hidden_size = hidden_size
     self.num_attention_heads = num_attention_heads
     self.intermediate_size = intermediate_size
@@ -211,6 +229,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
     self.key_query_shared_bottleneck = key_query_shared_bottleneck
     self.num_feedforward_networks = num_feedforward_networks
     self.normalization_type = normalization_type
+    self.initializer = tf.keras.initializers.get(initializer)

     if intra_bottleneck_size % num_attention_heads != 0:
       raise ValueError(
@@ -300,6 +319,24 @@ class MobileBertTransformer(tf.keras.layers.Layer):
         dropout_layer,
         layer_norm]

+  def get_config(self):
+    config = {
+        'hidden_size': self.hidden_size,
+        'num_attention_heads': self.num_attention_heads,
+        'intermediate_size': self.intermediate_size,
+        'intermediate_act_fn': self.intermediate_act_fn,
+        'hidden_dropout_prob': self.hidden_dropout_prob,
+        'attention_probs_dropout_prob': self.attention_probs_dropout_prob,
+        'intra_bottleneck_size': self.intra_bottleneck_size,
+        'use_bottleneck_attention': self.use_bottleneck_attention,
+        'key_query_shared_bottleneck': self.key_query_shared_bottleneck,
+        'num_feedforward_networks': self.num_feedforward_networks,
+        'normalization_type': self.normalization_type,
+        'initializer': tf.keras.initializers.serialize(self.initializer),
+    }
+    base_config = super(MobileBertTransformer, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
+
   def call(self,
            input_tensor,
            attention_mask=None,
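With these methods in place, a layer's configuration round-trips cleanly: the initializer is stored via tf.keras.initializers.serialize() and revived by the tf.keras.initializers.get() call in __init__. A sketch of the round trip, assuming the Model Garden package is importable under official.nlp as in this repository:

```python
import tensorflow as tf
from official.nlp.modeling.layers import mobile_bert_layers

embedding = mobile_bert_layers.MobileBertEmbedding(
    word_vocab_size=100,
    word_embed_size=32,
    type_vocab_size=2,
    output_embed_size=32)

# get_config() returns a plain dict (the initializer entry is itself a
# serialized dict), so it can be written out as JSON if desired.
config = embedding.get_config()
restored = mobile_bert_layers.MobileBertEmbedding.from_config(config)
assert config == restored.get_config()
```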
official/nlp/modeling/layers/mobile_bert_layers_test.py

@@ -51,6 +51,20 @@ class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):
     expected_shape = [1, 4, 16]
     self.assertListEqual(output_shape, expected_shape, msg=None)

+  def test_embedding_layer_get_config(self):
+    layer = mobile_bert_layers.MobileBertEmbedding(
+        word_vocab_size=16,
+        word_embed_size=32,
+        type_vocab_size=4,
+        output_embed_size=32,
+        max_sequence_length=32,
+        normalization_type='layer_norm',
+        initializer=tf.keras.initializers.TruncatedNormal(stddev=0.01),
+        dropout_rate=0.5)
+    layer_config = layer.get_config()
+    new_layer = mobile_bert_layers.MobileBertEmbedding.from_config(
+        layer_config)
+    self.assertEqual(layer_config, new_layer.get_config())
+
   def test_no_norm(self):
     layer = mobile_bert_layers.NoNorm()
     feature = tf.random.normal([2, 3, 4])
@@ -92,6 +106,26 @@ class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):
     self.assertListEqual(
         attention_score.shape.as_list(), expected_shape, msg=None)

+  def test_transformer_get_config(self):
+    layer = mobile_bert_layers.MobileBertTransformer(
+        hidden_size=32,
+        num_attention_heads=2,
+        intermediate_size=48,
+        intermediate_act_fn='gelu',
+        hidden_dropout_prob=0.5,
+        attention_probs_dropout_prob=0.4,
+        intra_bottleneck_size=64,
+        use_bottleneck_attention=True,
+        key_query_shared_bottleneck=False,
+        num_feedforward_networks=2,
+        normalization_type='layer_norm',
+        initializer=tf.keras.initializers.TruncatedNormal(stddev=0.01),
+        name='block')
+    layer_config = layer.get_config()
+    new_layer = mobile_bert_layers.MobileBertTransformer.from_config(
+        layer_config)
+    self.assertEqual(layer_config, new_layer.get_config())
+

 if __name__ == '__main__':
   tf.test.main()
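The tests above only check the config round trip; the practical benefit is that models built from these layers can be saved and reloaded as their original Python classes. A hedged sketch using standard Keras custom-object handling (the model construction here is illustrative, not taken from this commit):

```python
import tensorflow as tf
from official.nlp.modeling.layers import mobile_bert_layers

# A toy functional model wrapping the embedding layer.
input_ids = tf.keras.Input(shape=(16,), dtype=tf.int32)
embedded = mobile_bert_layers.MobileBertEmbedding(
    word_vocab_size=100, word_embed_size=32,
    type_vocab_size=2, output_embed_size=32)(input_ids)
model = tf.keras.Model(input_ids, embedded)

model.save('/tmp/mobile_bert_embedding')  # SavedModel format
restored = tf.keras.models.load_model(
    '/tmp/mobile_bert_embedding',
    custom_objects={
        'MobileBertEmbedding': mobile_bert_layers.MobileBertEmbedding})
```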