ModelZoo / ResNet50_tensorflow, commit 067e8ae3

Authored Sep 16, 2020 by Zhenyu Tan; committed by A. Unique TensorFlower, Sep 16, 2020.

Add BertEncoder to keras_nlp.

PiperOrigin-RevId: 332122408
Parent: 966da7c2
Showing 3 changed files with 476 additions and 0 deletions.
official/nlp/keras_nlp/encoders/__init__.py           +16   -0
official/nlp/keras_nlp/encoders/bert_encoder.py       +226  -0
official/nlp/keras_nlp/encoders/bert_encoder_test.py  +234  -0
official/nlp/keras_nlp/encoders/__init__.py (new file, mode 100644)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-NLP layers package definition."""
from official.nlp.keras_nlp.encoders.bert_encoder import BertEncoder
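
With this re-export in place, callers can import the class from the `encoders` package directly rather than from the module. A minimal sketch, assuming the Model Garden `official` package is on the Python path:

# Both import paths resolve to the same class; the package-level import is
# the one this __init__.py enables.
from official.nlp.keras_nlp.encoders import BertEncoder
from official.nlp.keras_nlp.encoders.bert_encoder import BertEncoder as _Same

assert BertEncoder is _Same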
official/nlp/keras_nlp/encoders/bert_encoder.py (new file, mode 100644)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bert encoder network."""
# pylint: disable=g-classes-have-attributes
import tensorflow as tf

from official.nlp.keras_nlp import layers


@tf.keras.utils.register_keras_serializable(package='keras_nlp')
class BertEncoder(tf.keras.Model):
"""Bi-directional Transformer-based encoder network.
This network implements a bi-directional Transformer-based encoder as
described in "BERT: Pre-training of Deep Bidirectional Transformers for
Language Understanding" (https://arxiv.org/abs/1810.04805). It includes the
embedding lookups and transformer layers, but not the masked language model
or classification task networks.
The default values for this object are taken from the BERT-Base implementation
in "BERT: Pre-training of Deep Bidirectional Transformers for Language
Understanding".
*Note* that the network is constructed by
[Keras Functional API](https://keras.io/guides/functional_api/).
Arguments:
vocab_size: The size of the token vocabulary.
hidden_size: The size of the transformer hidden layers.
num_layers: The number of transformer layers.
num_attention_heads: The number of attention heads for each transformer. The
hidden size must be divisible by the number of attention heads.
max_sequence_length: The maximum sequence length that this encoder can
consume. If None, max_sequence_length uses the value from sequence length.
This determines the variable shape for positional embeddings.
type_vocab_size: The number of types that the 'type_ids' input can take.
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network for each transformer.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network for each transformer.
output_dropout: Dropout probability for the post-attention and output
dropout.
attention_dropout: The dropout rate to use for the attention layers
within the transformer layers.
initializer: The initialzer to use for all weights in this encoder.
return_all_encoder_outputs: Whether to output sequence embedding outputs of
all encoder transformer layers.
output_range: The sequence output range, [0, output_range), by slicing the
target sequence of the last transformer layer. `None` means the entire
target sequence will attend to the source sequence, which yeilds the full
output.
embedding_width: The width of the word embeddings. If the embedding width is
not equal to hidden size, embedding parameters will be factorized into two
matrices in the shape of ['vocab_size', 'embedding_width'] and
['embedding_width', 'hidden_size'] ('embedding_width' is usually much
smaller than 'hidden_size').
"""
  def __init__(self,
               vocab_size,
               hidden_size=768,
               num_layers=12,
               num_attention_heads=12,
               max_sequence_length=512,
               type_vocab_size=16,
               inner_dim=3072,
               inner_activation=lambda x: tf.keras.activations.gelu(
                   x, approximate=True),
               output_dropout=0.1,
               attention_dropout=0.1,
               initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
               return_all_encoder_outputs=False,
               output_range=None,
               embedding_width=None,
               **kwargs):
    activation = tf.keras.activations.get(inner_activation)
    initializer = tf.keras.initializers.get(initializer)

    self._self_setattr_tracking = False
    self._config_dict = {
        'vocab_size': vocab_size,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'num_attention_heads': num_attention_heads,
        'max_sequence_length': max_sequence_length,
        'type_vocab_size': type_vocab_size,
        'inner_dim': inner_dim,
        'inner_activation': tf.keras.activations.serialize(activation),
        'output_dropout': output_dropout,
        'attention_dropout': attention_dropout,
        'initializer': tf.keras.initializers.serialize(initializer),
        'return_all_encoder_outputs': return_all_encoder_outputs,
        'output_range': output_range,
        'embedding_width': embedding_width,
    }
    word_ids = tf.keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_word_ids')
    mask = tf.keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_mask')
    type_ids = tf.keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_type_ids')
    if embedding_width is None:
      embedding_width = hidden_size

    self._embedding_layer = self._build_embedding_layer()
    word_embeddings = self._embedding_layer(word_ids)
    # Always uses dynamic slicing for simplicity.
    self._position_embedding_layer = layers.PositionEmbedding(
        initializer=initializer,
        max_length=max_sequence_length,
        name='position_embedding')
    position_embeddings = self._position_embedding_layer(word_embeddings)
    self._type_embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=type_vocab_size,
        embedding_width=embedding_width,
        initializer=initializer,
        use_one_hot=True,
        name='type_embeddings')
    type_embeddings = self._type_embedding_layer(type_ids)
    embeddings = tf.keras.layers.Add()(
        [word_embeddings, position_embeddings, type_embeddings])

    self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
        name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
    embeddings = self._embedding_norm_layer(embeddings)

    embeddings = tf.keras.layers.Dropout(rate=output_dropout)(embeddings)
    # We project the 'embedding' output to 'hidden_size' if it is not already
    # 'hidden_size'.
    if embedding_width != hidden_size:
      self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
          '...x,xy->...y',
          output_shape=hidden_size,
          bias_axes='y',
          kernel_initializer=initializer,
          name='embedding_projection')
      embeddings = self._embedding_projection(embeddings)
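
    # Only the last transformer layer is built with a non-None output range
    # (see the loop below): earlier layers must emit the full sequence so that
    # every subsequent layer can still attend over all source positions.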
    self._transformer_layers = []
    data = embeddings
    attention_mask = layers.SelfAttentionMask()(data, mask)
    encoder_outputs = []
    for i in range(num_layers):
      if i == num_layers - 1 and output_range is not None:
        transformer_output_range = output_range
      else:
        transformer_output_range = None
      layer = layers.TransformerEncoderBlock(
          num_attention_heads=num_attention_heads,
          inner_dim=inner_dim,
          inner_activation=inner_activation,
          output_dropout=output_dropout,
          attention_dropout=attention_dropout,
          output_range=transformer_output_range,
          kernel_initializer=initializer,
          name='transformer/layer_%d' % i)
      self._transformer_layers.append(layer)
      data = layer([data, attention_mask])
      encoder_outputs.append(data)
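
    # The pooled output is derived from the first (CLS) token of the final
    # sequence output, passed through a Dense layer with tanh activation, as
    # in the original BERT implementation.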
    first_token_tensor = (
        tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(
            encoder_outputs[-1]))
    self._pooler_layer = tf.keras.layers.Dense(
        units=hidden_size,
        activation='tanh',
        kernel_initializer=initializer,
        name='pooler_transform')
    cls_output = self._pooler_layer(first_token_tensor)
    outputs = dict(
        sequence_output=encoder_outputs[-1],
        pooled_output=cls_output,
        encoder_outputs=encoder_outputs,
    )

    super(BertEncoder, self).__init__(
        inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)
  def get_embedding_table(self):
    return self._embedding_layer.embeddings

  def _build_embedding_layer(self):
    embedding_width = self._config_dict[
        'embedding_width'] or self._config_dict['hidden_size']
    return layers.OnDeviceEmbedding(
        vocab_size=self._config_dict['vocab_size'],
        embedding_width=embedding_width,
        initializer=self._config_dict['initializer'],
        name='word_embeddings')

  def get_embedding_layer(self):
    return self._embedding_layer

  def get_config(self):
    return self._config_dict

  @property
  def transformer_layers(self):
    """List of Transformer layers in the encoder."""
    return self._transformer_layers

  @property
  def pooler_layer(self):
    """The pooler dense layer after the transformer layers."""
    return self._pooler_layer

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)
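
End to end, the encoder maps three int32 inputs (word ids, mask, type ids) to a dictionary of outputs. A minimal usage sketch with illustrative sizes (the defaults above correspond to BERT-Base), assuming the Model Garden `official` package is importable:

import numpy as np

from official.nlp.keras_nlp.encoders import bert_encoder

# Small illustrative configuration, not the BERT-Base defaults.
encoder = bert_encoder.BertEncoder(
    vocab_size=100, hidden_size=32, num_attention_heads=2, num_layers=3)

batch_size, seq_len = 2, 16
word_ids = np.random.randint(100, size=(batch_size, seq_len), dtype=np.int32)
mask = np.ones((batch_size, seq_len), dtype=np.int32)
type_ids = np.zeros((batch_size, seq_len), dtype=np.int32)

outputs = encoder([word_ids, mask, type_ids])
# 'sequence_output': [batch, seq, hidden]; 'pooled_output': [batch, hidden];
# 'encoder_outputs': the per-layer sequence output for each of the 3 layers.
print(outputs['sequence_output'].shape, outputs['pooled_output'].shape)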
official/nlp/keras_nlp/encoders/bert_encoder_test.py (new file, mode 100644)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for transformer-based bert encoder network."""
# Import libraries
from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import

from official.nlp.keras_nlp.encoders import bert_encoder
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class BertEncoderTest(keras_parameterized.TestCase):

  def tearDown(self):
    super(BertEncoderTest, self).tearDown()
    tf.keras.mixed_precision.experimental.set_policy("float32")
  def test_network_creation(self):
    hidden_size = 32
    sequence_length = 21
    # Create a small BertEncoder for testing.
    test_network = bert_encoder.BertEncoder(
        vocab_size=100,
        hidden_size=hidden_size,
        num_attention_heads=2,
        num_layers=3)
    # Create the inputs (note that the first dimension is implicit).
    word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    dict_outputs = test_network([word_ids, mask, type_ids])
    data = dict_outputs["sequence_output"]
    pooled = dict_outputs["pooled_output"]

    self.assertIsInstance(test_network.transformer_layers, list)
    self.assertLen(test_network.transformer_layers, 3)
    self.assertIsInstance(test_network.pooler_layer, tf.keras.layers.Dense)

    expected_data_shape = [None, sequence_length, hidden_size]
    expected_pooled_shape = [None, hidden_size]
    self.assertAllEqual(expected_data_shape, data.shape.as_list())
    self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())

    # The default output dtype is float32.
    self.assertAllEqual(tf.float32, data.dtype)
    self.assertAllEqual(tf.float32, pooled.dtype)
  def test_all_encoder_outputs_network_creation(self):
    hidden_size = 32
    sequence_length = 21
    # Create a small BertEncoder for testing.
    test_network = bert_encoder.BertEncoder(
        vocab_size=100,
        hidden_size=hidden_size,
        num_attention_heads=2,
        num_layers=3,
        return_all_encoder_outputs=True)
    # Create the inputs (note that the first dimension is implicit).
    word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    dict_outputs = test_network([word_ids, mask, type_ids])
    all_encoder_outputs = dict_outputs["encoder_outputs"]
    pooled = dict_outputs["pooled_output"]

    expected_data_shape = [None, sequence_length, hidden_size]
    expected_pooled_shape = [None, hidden_size]
    self.assertLen(all_encoder_outputs, 3)
    for data in all_encoder_outputs:
      self.assertAllEqual(expected_data_shape, data.shape.as_list())
    self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())

    # The default output dtype is float32.
    self.assertAllEqual(tf.float32, all_encoder_outputs[-1].dtype)
    self.assertAllEqual(tf.float32, pooled.dtype)
  def test_network_creation_with_float16_dtype(self):
    hidden_size = 32
    sequence_length = 21
    tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
    # Create a small BertEncoder for testing.
    test_network = bert_encoder.BertEncoder(
        vocab_size=100,
        hidden_size=hidden_size,
        num_attention_heads=2,
        num_layers=3)
    # Create the inputs (note that the first dimension is implicit).
    word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    dict_outputs = test_network([word_ids, mask, type_ids])
    data = dict_outputs["sequence_output"]
    pooled = dict_outputs["pooled_output"]

    expected_data_shape = [None, sequence_length, hidden_size]
    expected_pooled_shape = [None, hidden_size]
    self.assertAllEqual(expected_data_shape, data.shape.as_list())
    self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())

    # If float_dtype is set to float16, the data output is float32 (from a
    # layer norm) and the pooled output should be float16.
    self.assertAllEqual(tf.float32, data.dtype)
    self.assertAllEqual(tf.float16, pooled.dtype)
  @parameterized.named_parameters(
      ("all_sequence", None, 21),
      ("output_range", 1, 1),
  )
  def test_network_invocation(self, output_range, out_seq_len):
    hidden_size = 32
    sequence_length = 21
    vocab_size = 57
    num_types = 7
    # Create a small BertEncoder for testing.
    test_network = bert_encoder.BertEncoder(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        num_attention_heads=2,
        num_layers=3,
        type_vocab_size=num_types,
        output_range=output_range)
    # Create the inputs (note that the first dimension is implicit).
    word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    dict_outputs = test_network([word_ids, mask, type_ids])
    data = dict_outputs["sequence_output"]
    pooled = dict_outputs["pooled_output"]

    # Create a model based off of this network:
    model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])

    # Invoke the model. We can't validate the output data here (the model is
    # too complex) but this will catch structural runtime errors.
    batch_size = 3
    word_id_data = np.random.randint(
        vocab_size, size=(batch_size, sequence_length))
    mask_data = np.random.randint(2, size=(batch_size, sequence_length))
    type_id_data = np.random.randint(
        num_types, size=(batch_size, sequence_length))
    outputs = model.predict([word_id_data, mask_data, type_id_data])
    self.assertEqual(outputs[0].shape[1], out_seq_len)

    # Creates a BertEncoder with max_sequence_length != sequence_length.
    max_sequence_length = 128
    test_network = bert_encoder.BertEncoder(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        max_sequence_length=max_sequence_length,
        num_attention_heads=2,
        num_layers=3,
        type_vocab_size=num_types)
    dict_outputs = test_network([word_ids, mask, type_ids])
    data = dict_outputs["sequence_output"]
    pooled = dict_outputs["pooled_output"]
    model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
    outputs = model.predict([word_id_data, mask_data, type_id_data])
    self.assertEqual(outputs[0].shape[1], sequence_length)

    # Creates a BertEncoder with embedding_width != hidden_size.
    test_network = bert_encoder.BertEncoder(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        max_sequence_length=max_sequence_length,
        num_attention_heads=2,
        num_layers=3,
        type_vocab_size=num_types,
        embedding_width=16)
    dict_outputs = test_network([word_ids, mask, type_ids])
    data = dict_outputs["sequence_output"]
    pooled = dict_outputs["pooled_output"]
    model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
    outputs = model.predict([word_id_data, mask_data, type_id_data])
    self.assertEqual(outputs[0].shape[-1], hidden_size)
    self.assertTrue(hasattr(test_network, "_embedding_projection"))
  def test_serialize_deserialize(self):
    # Create a network object that sets all of its config options.
    kwargs = dict(
        vocab_size=100,
        hidden_size=32,
        num_layers=3,
        num_attention_heads=2,
        max_sequence_length=21,
        type_vocab_size=12,
        inner_dim=1223,
        inner_activation="relu",
        output_dropout=0.05,
        attention_dropout=0.22,
        initializer="glorot_uniform",
        return_all_encoder_outputs=False,
        output_range=-1,
        embedding_width=16)
    network = bert_encoder.BertEncoder(**kwargs)

    expected_config = dict(kwargs)
    expected_config["inner_activation"] = tf.keras.activations.serialize(
        tf.keras.activations.get(expected_config["inner_activation"]))
    expected_config["initializer"] = tf.keras.initializers.serialize(
        tf.keras.initializers.get(expected_config["initializer"]))
    self.assertEqual(network.get_config(), expected_config)

    # Create another network object from the first object's config.
    new_network = bert_encoder.BertEncoder.from_config(network.get_config())

    # Validate that the config can be forced to JSON.
    _ = network.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(network.get_config(), new_network.get_config())

    # Tests model saving/loading.
    model_path = self.get_temp_dir() + "/model"
    network.save(model_path)
    _ = tf.keras.models.load_model(model_path)
if __name__ == "__main__":
  tf.test.main()
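
Because the encoder is registered via @tf.keras.utils.register_keras_serializable and round-trips through get_config/from_config, the save/load path exercised in test_serialize_deserialize can also be used directly. A minimal sketch, with an illustrative save path and sizes:

import tensorflow as tf

from official.nlp.keras_nlp.encoders import bert_encoder

network = bert_encoder.BertEncoder(
    vocab_size=100, hidden_size=32, num_attention_heads=2, num_layers=3)

# Rebuild an equivalent encoder from the serialized config.
clone = bert_encoder.BertEncoder.from_config(network.get_config())

# Save the whole model and reload it; "/tmp/bert_encoder" is illustrative.
network.save("/tmp/bert_encoder")
reloaded = tf.keras.models.load_model("/tmp/bert_encoder")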