ModelZoo / ResNet50_tensorflow · Commit 87d7e974
Authored Dec 10, 2020 by Chen Chen; committed by A. Unique TensorFlower, Dec 10, 2020.

Move MobileBertMaskedLM to official/nlp/modeling/layers/mobile_bert_layers.py

PiperOrigin-RevId: 346826537
Parent: a4b767b7
Showing 3 changed files with 272 additions and 0 deletions:

  official/nlp/modeling/layers/__init__.py                  +1    -0
  official/nlp/modeling/layers/mobile_bert_layers.py         +129  -0
  official/nlp/modeling/layers/mobile_bert_layers_test.py    +142  -0
official/nlp/modeling/layers/__init__.py
@@ -22,6 +22,7 @@ from official.nlp.modeling.layers.masked_lm import MaskedLM
from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
from official.nlp.modeling.layers.mat_mul_with_margin import MatMulWithMargin
from official.nlp.modeling.layers.mobile_bert_layers import MobileBertEmbedding
from official.nlp.modeling.layers.mobile_bert_layers import MobileBertMaskedLM
from official.nlp.modeling.layers.mobile_bert_layers import MobileBertTransformer
from official.nlp.modeling.layers.multi_channel_attention import *
from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
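With this export in __init__.py, MobileBertMaskedLM is reachable both through its defining module and through the public layers namespace. A minimal check of that equivalence (a sketch, not part of the commit; it assumes the repository's official package is importable, e.g. via an installed tf-models package or PYTHONPATH):

    # Both import paths should resolve to the same class after this commit.
    from official.nlp.modeling import layers
    from official.nlp.modeling.layers.mobile_bert_layers import MobileBertMaskedLM

    assert layers.MobileBertMaskedLM is MobileBertMaskedLM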
official/nlp/modeling/layers/mobile_bert_layers.py
@@ -18,6 +18,7 @@ import tensorflow as tf
from official.nlp import keras_nlp


@tf.keras.utils.register_keras_serializable(package='Text')
class NoNorm(tf.keras.layers.Layer):
  """Apply element-wise linear transformation to the last dimension."""
@@ -62,6 +63,7 @@ def _get_norm_layer(normalization_type='no_norm', name=None):
  return layer


@tf.keras.utils.register_keras_serializable(package='Text')
class MobileBertEmbedding(tf.keras.layers.Layer):
  """Performs an embedding lookup for MobileBERT.
@@ -163,6 +165,7 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
    return embedding_out


@tf.keras.utils.register_keras_serializable(package='Text')
class MobileBertTransformer(tf.keras.layers.Layer):
  """Transformer block for MobileBERT.
@@ -422,3 +425,129 @@ class MobileBertTransformer(tf.keras.layers.Layer):
      return layer_output, attention_scores
    else:
      return layer_output


@tf.keras.utils.register_keras_serializable(package='Text')
class MobileBertMaskedLM(tf.keras.layers.Layer):
  """Masked language model network head for BERT modeling.

  This layer implements a masked language model based on the provided
  transformer based encoder. It assumes that the encoder network being passed
  has a "get_embedding_table()" method. Different from canonical BERT's masked
  LM layer, when the embedding width is smaller than hidden_size, it adds an
  extra output weights in shape [vocab_size, (hidden_size - embedding_width)].
  """

  def __init__(self,
               embedding_table,
               activation=None,
               initializer='glorot_uniform',
               output='logits',
               **kwargs):
    """Class initialization.

    Arguments:
      embedding_table: The embedding table from encoder network.
      activation: The activation, if any, for the dense layer.
      initializer: The initializer for the dense layer. Defaults to a Glorot
        uniform initializer.
      output: The output style for this layer. Can be either 'logits' or
        'predictions'.
      **kwargs: keyword arguments.
    """
    super(MobileBertMaskedLM, self).__init__(**kwargs)
    self.embedding_table = embedding_table
    self.activation = activation
    self.initializer = tf.keras.initializers.get(initializer)

    if output not in ('predictions', 'logits'):
      raise ValueError(
          ('Unknown `output` value "%s". `output` can be either "logits" or '
           '"predictions"') % output)
    self._output_type = output

  def build(self, input_shape):
    self._vocab_size, embedding_width = self.embedding_table.shape
    hidden_size = input_shape[-1]
    self.dense = tf.keras.layers.Dense(
        hidden_size,
        activation=self.activation,
        kernel_initializer=self.initializer,
        name='transform/dense')

    if hidden_size > embedding_width:
      self.extra_output_weights = self.add_weight(
          'extra_output_weights',
          shape=(self._vocab_size, hidden_size - embedding_width),
          initializer=self.initializer,
          trainable=True)
    elif hidden_size == embedding_width:
      self.extra_output_weights = None
    else:
      raise ValueError(
          'hidden size %d cannot be smaller than embedding width %d.' %
          (hidden_size, embedding_width))

    self.layer_norm = tf.keras.layers.LayerNormalization(
        axis=-1, epsilon=1e-12, name='transform/LayerNorm')
    self.bias = self.add_weight(
        'output_bias/bias',
        shape=(self._vocab_size,),
        initializer='zeros',
        trainable=True)

    super(MobileBertMaskedLM, self).build(input_shape)

  def call(self, sequence_data, masked_positions):
    masked_lm_input = self._gather_indexes(sequence_data, masked_positions)
    lm_data = self.dense(masked_lm_input)
    lm_data = self.layer_norm(lm_data)
    if self.extra_output_weights is None:
      lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True)
    else:
      lm_data = tf.matmul(
          lm_data,
          tf.concat([self.embedding_table, self.extra_output_weights], axis=1),
          transpose_b=True)
    logits = tf.nn.bias_add(lm_data, self.bias)

    masked_positions_length = masked_positions.shape.as_list()[
        1] or tf.shape(masked_positions)[1]
    logits = tf.reshape(logits,
                        [-1, masked_positions_length, self._vocab_size])
    if self._output_type == 'logits':
      return logits
    return tf.nn.log_softmax(logits)

  def get_config(self):
    raise NotImplementedError('MaskedLM cannot be directly serialized because '
                              'it has variable sharing logic.')

  def _gather_indexes(self, sequence_tensor, positions):
    """Gathers the vectors at the specific positions.

    Args:
      sequence_tensor: Sequence output of `BertModel` layer of shape
        (`batch_size`, `seq_length`, num_hidden) where num_hidden is number of
        hidden units of `BertModel` layer.
      positions: Positions ids of tokens in sequence to mask for pretraining
        of with dimension (batch_size, num_predictions) where
        `num_predictions` is maximum number of tokens to mask out and predict
        per each sequence.

    Returns:
      Masked out sequence tensor of shape (batch_size * num_predictions,
      num_hidden).
    """
    sequence_shape = tf.shape(sequence_tensor)
    batch_size, seq_length = sequence_shape[0], sequence_shape[1]
    width = sequence_tensor.shape.as_list()[2] or sequence_shape[2]

    flat_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.reshape(sequence_tensor,
                                      [batch_size * seq_length, width])
    output_tensor = tf.gather(flat_sequence_tensor, flat_positions)

    return output_tensor
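The class above only needs an embedding table and the encoder's hidden states to attach a masked-LM head. A minimal usage sketch (not part of the commit; the vocab size, embedding width, and hidden size below are illustrative, and in practice the table would come from an encoder's get_embedding_table(), as the tests below do):

    import numpy as np
    import tensorflow as tf

    from official.nlp.modeling.layers import mobile_bert_layers

    vocab_size, embedding_width, hidden_size = 100, 32, 64
    # Stand-in for an encoder's word embedding table: [vocab_size, embedding_width].
    embedding_table = tf.Variable(tf.random.normal([vocab_size, embedding_width]))
    masked_lm = mobile_bert_layers.MobileBertMaskedLM(
        embedding_table=embedding_table, output='logits')

    batch_size, seq_length, num_predictions = 2, 16, 5
    sequence_data = tf.random.normal([batch_size, seq_length, hidden_size])
    masked_positions = tf.constant(
        np.random.randint(seq_length, size=(batch_size, num_predictions)),
        dtype=tf.int32)

    # hidden_size (64) > embedding_width (32), so build() adds
    # extra_output_weights of shape [vocab_size, hidden_size - embedding_width].
    logits = masked_lm(sequence_data, masked_positions)
    print(logits.shape)  # (batch_size, num_predictions, vocab_size)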
official/nlp/modeling/layers/mobile_bert_layers_test.py
@@ -18,6 +18,7 @@ import numpy as np
import tensorflow as tf

from official.nlp.modeling.layers import mobile_bert_layers
from official.nlp.modeling.networks import mobile_bert_encoder


def generate_fake_input(batch_size=1, seq_len=5, vocab_size=10000, seed=0):
@@ -127,5 +128,146 @@ class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):
    self.assertEqual(layer_config, new_layer.get_config())


class MobileBertMaskedLMTest(tf.test.TestCase):

  def create_layer(self,
                   vocab_size,
                   hidden_size,
                   embedding_width,
                   output='predictions',
                   xformer_stack=None):
    # First, create a transformer stack that we can use to get the LM's
    # vocabulary weight.
    if xformer_stack is None:
      xformer_stack = mobile_bert_encoder.MobileBERTEncoder(
          word_vocab_size=vocab_size,
          num_blocks=1,
          hidden_size=hidden_size,
          num_attention_heads=4,
          word_embed_size=embedding_width)

    # Create a maskedLM from the transformer stack.
    test_layer = mobile_bert_layers.MobileBertMaskedLM(
        embedding_table=xformer_stack.get_embedding_table(), output=output)
    return test_layer

  def test_layer_creation(self):
    vocab_size = 100
    sequence_length = 32
    hidden_size = 64
    embedding_width = 32
    num_predictions = 21
    test_layer = self.create_layer(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        embedding_width=embedding_width)

    # Make sure that the output tensor of the masked LM is the right shape.
    lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
    masked_positions = tf.keras.Input(shape=(num_predictions,), dtype=tf.int32)
    output = test_layer(lm_input_tensor, masked_positions=masked_positions)

    expected_output_shape = [None, num_predictions, vocab_size]
    self.assertEqual(expected_output_shape, output.shape.as_list())

  def test_layer_invocation_with_external_logits(self):
    vocab_size = 100
    sequence_length = 32
    hidden_size = 64
    embedding_width = 32
    num_predictions = 21
    xformer_stack = mobile_bert_encoder.MobileBERTEncoder(
        word_vocab_size=vocab_size,
        num_blocks=1,
        hidden_size=hidden_size,
        num_attention_heads=4,
        word_embed_size=embedding_width)
    test_layer = self.create_layer(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        embedding_width=embedding_width,
        xformer_stack=xformer_stack,
        output='predictions')
    logit_layer = self.create_layer(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        embedding_width=embedding_width,
        xformer_stack=xformer_stack,
        output='logits')

    # Create a model from the masked LM layer.
    lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
    masked_positions = tf.keras.Input(shape=(num_predictions,), dtype=tf.int32)
    output = test_layer(lm_input_tensor, masked_positions)
    logit_output = logit_layer(lm_input_tensor, masked_positions)
    logit_output = tf.keras.layers.Activation(tf.nn.log_softmax)(logit_output)
    logit_layer.set_weights(test_layer.get_weights())
    model = tf.keras.Model([lm_input_tensor, masked_positions], output)
    logits_model = tf.keras.Model(([lm_input_tensor, masked_positions]),
                                  logit_output)

    # Invoke the masked LM on some fake data to make sure there are no runtime
    # errors in the code.
    batch_size = 3
    lm_input_data = 10 * np.random.random_sample(
        (batch_size, sequence_length, hidden_size))
    masked_position_data = np.random.randint(
        sequence_length, size=(batch_size, num_predictions))
    # ref_outputs = model.predict([lm_input_data, masked_position_data])
    # outputs = logits_model.predict([lm_input_data, masked_position_data])
    ref_outputs = model([lm_input_data, masked_position_data])
    outputs = logits_model([lm_input_data, masked_position_data])

    # Ensure that the tensor shapes are correct.
    expected_output_shape = (batch_size, num_predictions, vocab_size)
    self.assertEqual(expected_output_shape, ref_outputs.shape)
    self.assertEqual(expected_output_shape, outputs.shape)
    self.assertAllClose(ref_outputs, outputs)

  def test_layer_invocation(self):
    vocab_size = 100
    sequence_length = 32
    hidden_size = 64
    embedding_width = 32
    num_predictions = 21
    test_layer = self.create_layer(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        embedding_width=embedding_width)

    # Create a model from the masked LM layer.
    lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
    masked_positions = tf.keras.Input(shape=(num_predictions,), dtype=tf.int32)
    output = test_layer(lm_input_tensor, masked_positions)
    model = tf.keras.Model([lm_input_tensor, masked_positions], output)

    # Invoke the masked LM on some fake data to make sure there are no runtime
    # errors in the code.
    batch_size = 3
    lm_input_data = 10 * np.random.random_sample(
        (batch_size, sequence_length, hidden_size))
    masked_position_data = np.random.randint(
        2, size=(batch_size, num_predictions))
    _ = model.predict([lm_input_data, masked_position_data])

  def test_unknown_output_type_fails(self):
    with self.assertRaisesRegex(ValueError,
                                'Unknown `output` value "bad".*'):
      _ = self.create_layer(
          vocab_size=8, hidden_size=8, embedding_width=4, output='bad')

  def test_hidden_size_smaller_than_embedding_width(self):
    hidden_size = 8
    sequence_length = 32
    num_predictions = 20
    with self.assertRaisesRegex(
        ValueError,
        'hidden size 8 cannot be smaller than embedding width 16.'):
      test_layer = self.create_layer(
          vocab_size=8, hidden_size=8, embedding_width=16)
      lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
      masked_positions = tf.keras.Input(
          shape=(num_predictions,), dtype=tf.int32)
      _ = test_layer(lm_input_tensor, masked_positions)


if __name__ == '__main__':
  tf.test.main()
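Since the module ends with tf.test.main(), the new tests can be run directly once TensorFlow and the repository's official package are importable (a sketch of the invocation, not a documented entry point):

    python official/nlp/modeling/layers/mobile_bert_layers_test.py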