ModelZoo / ResNet50_tensorflow

Commit 21ce83d8, authored Oct 08, 2021 by Hongkun Yu, committed by A. Unique TensorFlower on Oct 08, 2021.

Merge BertEncoderV2 and BertEncoder.

PiperOrigin-RevId: 401789124
Parent: 1a8a4662

Showing 5 changed files with 90 additions and 497 deletions (+90, -497).
Changed files:
  official/nlp/keras_nlp/encoders/bert_encoder.py        +6   -242
  official/nlp/keras_nlp/encoders/bert_encoder_test.py    +0   -6
  official/nlp/modeling/networks/__init__.py              +0   -1
  official/nlp/modeling/networks/bert_encoder.py          +59  -166
  official/nlp/modeling/networks/bert_encoder_test.py     +25  -82
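Taken together, the changes below amount to a class rename plus an output-mode flag. A minimal before/after sketch of what a caller of the modeling library would change, assuming the official.nlp.modeling package from tf-models-official is importable (the call site and sizes are illustrative, not part of the diff):

# Before this commit the dict-output encoder lived in a separate V2 class:
#   from official.nlp.modeling import networks
#   encoder = networks.BertEncoderV2(
#       vocab_size=100, hidden_size=32, num_attention_heads=2, num_layers=3)

# After this commit there is one class; dictionary outputs are requested with a flag.
from official.nlp.modeling import networks

encoder = networks.BertEncoder(
    vocab_size=100,
    hidden_size=32,
    num_attention_heads=2,
    num_layers=3,
    dict_outputs=True)  # the default (False) keeps the V1-style list outputs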
official/nlp/keras_nlp/encoders/bert_encoder.py (+6, -242)

This file loses its standalone Functional-API implementation and becomes a thin deprecated wrapper around the modeling-library encoder; the diff shows the old implementation interleaved with the handful of replacement lines.

@@ -15,254 +15,18 @@
"""Bert encoder network."""
# pylint: disable=g-classes-have-attributes
import
collections
from
absl
import
logging
import
tensorflow
as
tf
from
official.nlp.
keras_nlp
import
layer
s
from
official.nlp.
modeling
import
network
s
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'keras_nlp'
)
class
BertEncoder
(
tf
.
keras
.
Model
):
"""Bi-directional Transformer-based encoder network.
This network implements a bi-directional Transformer-based encoder as
described in "BERT: Pre-training of Deep Bidirectional Transformers for
Language Understanding" (https://arxiv.org/abs/1810.04805). It includes the
embedding lookups and transformer layers, but not the masked language model
or classification task networks.
The default values for this object are taken from the BERT-Base implementation
in "BERT: Pre-training of Deep Bidirectional Transformers for Language
Understanding".
*Note* that the network is constructed by
[Keras Functional API](https://keras.io/guides/functional_api/).
Args:
vocab_size: The size of the token vocabulary.
hidden_size: The size of the transformer hidden layers.
num_layers: The number of transformer layers.
num_attention_heads: The number of attention heads for each transformer. The
hidden size must be divisible by the number of attention heads.
max_sequence_length: The maximum sequence length that this encoder can
consume. If None, max_sequence_length uses the value from sequence length.
This determines the variable shape for positional embeddings.
type_vocab_size: The number of types that the 'type_ids' input can take.
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network for each transformer.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network for each transformer.
output_dropout: Dropout probability for the post-attention and output
dropout.
attention_dropout: The dropout rate to use for the attention layers
within the transformer layers.
initializer: The initialzer to use for all weights in this encoder.
output_range: The sequence output range, [0, output_range), by slicing the
target sequence of the last transformer layer. `None` means the entire
target sequence will attend to the source sequence, which yields the full
output.
embedding_width: The width of the word embeddings. If the embedding width is
not equal to hidden size, embedding parameters will be factorized into two
matrices in the shape of ['vocab_size', 'embedding_width'] and
['embedding_width', 'hidden_size'] ('embedding_width' is usually much
smaller than 'hidden_size').
embedding_layer: An optional Layer instance which will be called to
generate embeddings for the input word IDs.
norm_first: Whether to normalize inputs to attention and intermediate
dense layers. If set False, output of attention and intermediate dense
layers is normalized.
"""
The replacement that this commit puts in its place is a thin subclass of the modeling-library encoder:

class BertEncoder(networks.BertEncoder):
  """Deprecated."""
  def __init__(
      self,
      vocab_size,
      hidden_size=768,
      num_layers=12,
      num_attention_heads=12,
      max_sequence_length=512,
      type_vocab_size=16,
      inner_dim=3072,
      inner_activation=lambda x: tf.keras.activations.gelu(x, approximate=True),
      output_dropout=0.1,
      attention_dropout=0.1,
      initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
      output_range=None,
      embedding_width=None,
      embedding_layer=None,
      norm_first=False,
      **kwargs):
    activation = tf.keras.activations.get(inner_activation)
    initializer = tf.keras.initializers.get(initializer)

    word_ids = tf.keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_word_ids')
    mask = tf.keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_mask')
    type_ids = tf.keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_type_ids')

    if embedding_width is None:
      embedding_width = hidden_size

    if embedding_layer is None:
      embedding_layer_inst = layers.OnDeviceEmbedding(
          vocab_size=vocab_size,
          embedding_width=embedding_width,
          initializer=initializer,
          name='word_embeddings')
    else:
      embedding_layer_inst = embedding_layer
    word_embeddings = embedding_layer_inst(word_ids)

    # Always uses dynamic slicing for simplicity.
    position_embedding_layer = layers.PositionEmbedding(
        initializer=initializer,
        max_length=max_sequence_length,
        name='position_embedding')
    position_embeddings = position_embedding_layer(word_embeddings)

    type_embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=type_vocab_size,
        embedding_width=embedding_width,
        initializer=initializer,
        use_one_hot=True,
        name='type_embeddings')
    type_embeddings = type_embedding_layer(type_ids)

    embeddings = tf.keras.layers.Add()(
        [word_embeddings, position_embeddings, type_embeddings])

    embedding_norm_layer = tf.keras.layers.LayerNormalization(
        name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
    embeddings = embedding_norm_layer(embeddings)

    embeddings = (tf.keras.layers.Dropout(rate=output_dropout)(embeddings))

    # We project the 'embedding' output to 'hidden_size' if it is not already
    # 'hidden_size'.
    if embedding_width != hidden_size:
      embedding_projection = tf.keras.layers.experimental.EinsumDense(
          '...x,xy->...y',
          output_shape=hidden_size,
          bias_axes='y',
          kernel_initializer=initializer,
          name='embedding_projection')
      embeddings = embedding_projection(embeddings)
    else:
      embedding_projection = None

    transformer_layers = []
    data = embeddings
    attention_mask = layers.SelfAttentionMask()(data, mask)
    encoder_outputs = []
    for i in range(num_layers):
      if i == num_layers - 1 and output_range is not None:
        transformer_output_range = output_range
      else:
        transformer_output_range = None
      layer = layers.TransformerEncoderBlock(
          num_attention_heads=num_attention_heads,
          inner_dim=inner_dim,
          inner_activation=inner_activation,
          output_dropout=output_dropout,
          attention_dropout=attention_dropout,
          norm_first=norm_first,
          output_range=transformer_output_range,
          kernel_initializer=initializer,
          name='transformer/layer_%d' % i)
      transformer_layers.append(layer)
      data = layer([data, attention_mask])
      encoder_outputs.append(data)

    last_encoder_output = encoder_outputs[-1]
    # Applying a tf.slice op (through subscript notation) to a Keras tensor
    # like this will create a SliceOpLambda layer. This is better than a Lambda
    # layer with Python code, because that is fundamentally less portable.
    first_token_tensor = last_encoder_output[:, 0, :]
    pooler_layer = tf.keras.layers.Dense(
        units=hidden_size,
        activation='tanh',
        kernel_initializer=initializer,
        name='pooler_transform')
    cls_output = pooler_layer(first_token_tensor)

    outputs = dict(
        sequence_output=encoder_outputs[-1],
        pooled_output=cls_output,
        encoder_outputs=encoder_outputs,
    )

    # Once we've created the network using the Functional API, we call
    # super().__init__ as though we were invoking the Functional API Model
    # constructor, resulting in this object having all the properties of a model
    # created using the Functional API. Once super().__init__ is called, we
    # can assign attributes to `self` - note that all `self` assignments are
    # below this line.
    super(BertEncoder, self).__init__(
        inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)

    config_dict = {
        'vocab_size': vocab_size,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'num_attention_heads': num_attention_heads,
        'max_sequence_length': max_sequence_length,
        'type_vocab_size': type_vocab_size,
        'inner_dim': inner_dim,
        'inner_activation': tf.keras.activations.serialize(activation),
        'output_dropout': output_dropout,
        'attention_dropout': attention_dropout,
        'initializer': tf.keras.initializers.serialize(initializer),
        'output_range': output_range,
        'embedding_width': embedding_width,
        'embedding_layer': embedding_layer,
        'norm_first': norm_first,
    }

    # We are storing the config dict as a namedtuple here to ensure checkpoint
    # compatibility with an earlier version of this model which did not track
    # the config dict attribute. TF does not track immutable attrs which
    # do not contain Trackables, so by creating a config namedtuple instead of
    # a dict we avoid tracking it.
    config_cls = collections.namedtuple('Config', config_dict.keys())
    self._config = config_cls(**config_dict)
    self._pooler_layer = pooler_layer
    self._transformer_layers = transformer_layers
    self._embedding_norm_layer = embedding_norm_layer
    self._embedding_layer = embedding_layer_inst
    self._position_embedding_layer = position_embedding_layer
    self._type_embedding_layer = type_embedding_layer
    if embedding_projection is not None:
      self._embedding_projection = embedding_projection

  def get_embedding_table(self):
    return self._embedding_layer.embeddings

  def get_embedding_layer(self):
    return self._embedding_layer

  def get_config(self):
    return dict(self._config._asdict())

  @property
  def transformer_layers(self):
    """List of Transformer layers in the encoder."""
    return self._transformer_layers

  @property
  def pooler_layer(self):
    """The pooler dense layer after the transformer layers."""
    return self._pooler_layer

  @classmethod
  def from_config(cls, config, custom_objects=None):
    if 'embedding_layer' in config and config['embedding_layer'] is not None:
      warn_string = (
          'You are reloading a model that was saved with a '
          'potentially-shared embedding layer object. If you contine to '
          'train this model, the embedding layer will no longer be shared. '
          'To work around this, load the model outside of the Keras API.')
      print('WARNING: ' + warn_string)
      logging.warn(warn_string)
    return cls(**config)
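The embedding_width factorization that the code above implements with an '...x,xy->...y' EinsumDense projection is easy to quantify. A back-of-the-envelope sketch with BERT-Base-like sizes, chosen only for illustration:

vocab_size, hidden_size, embedding_width = 30522, 768, 128

# Single embedding matrix when embedding_width == hidden_size.
full = vocab_size * hidden_size                            # 23,440,896 parameters

# Factorized: a [vocab_size, embedding_width] lookup plus the
# [embedding_width, hidden_size] projection built above (ignoring its bias).
factorized = vocab_size * embedding_width + embedding_width * hidden_size
print(full, factorized)                                     # 23440896 4005120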
The remaining added lines belong to the wrapper's constructor: any user-supplied dict_outputs is discarded and dictionary outputs are forced.

    if 'dict_outputs' in kwargs:
      kwargs.pop('dict_outputs')
    super().__init__(dict_outputs=True, **kwargs)
official/nlp/keras_nlp/encoders/bert_encoder_test.py (+0, -6)

@@ -213,16 +213,10 @@ class BertEncoderTest(keras_parameterized.TestCase):
        tf.keras.activations.get(expected_config["inner_activation"]))
    expected_config["initializer"] = tf.keras.initializers.serialize(
        tf.keras.initializers.get(expected_config["initializer"]))
    self.assertEqual(network.get_config(), expected_config)

    # Create another network object from the first object's config.
    new_network = bert_encoder.BertEncoder.from_config(network.get_config())

    # Validate that the config can be forced to JSON.
    _ = network.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(network.get_config(), new_network.get_config())

    # Tests model saving/loading.
    model_path = self.get_temp_dir() + "/model"
    network.save(model_path)
...
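The get_config()/from_config() round trip this test exercises relies on the register_keras_serializable decorators seen on both encoder classes: registering the class under a package name lets Keras rebuild it from a plain config during deserialization. A toy sketch of the mechanism with a made-up layer (ScaleLayer and its package name are hypothetical, not from the repo):

import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package='Demo')
class ScaleLayer(tf.keras.layers.Layer):
  """Multiplies its input by a configurable constant."""

  def __init__(self, factor=2.0, **kwargs):
    super().__init__(**kwargs)
    self.factor = factor

  def call(self, inputs):
    return inputs * self.factor

  def get_config(self):
    config = super().get_config()
    config.update({'factor': self.factor})
    return config


layer = ScaleLayer(factor=3.0)
serialized = tf.keras.layers.serialize(layer)      # class_name is 'Demo>ScaleLayer'
rebuilt = tf.keras.layers.deserialize(serialized)  # found via the registry, no custom_objects
print(rebuilt(tf.constant([1.0, 2.0])))            # [3. 6.]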
official/nlp/modeling/networks/__init__.py (+0, -1)

The single removed line is the BertEncoderV2 export, consistent with the merge.

@@ -20,7 +20,6 @@ handled object with a standardized configuration.
"""
from official.nlp.modeling.networks.albert_encoder import AlbertEncoder
from official.nlp.modeling.networks.bert_encoder import BertEncoder
from official.nlp.modeling.networks.bert_encoder import BertEncoderV2
from official.nlp.modeling.networks.classification import Classification
from official.nlp.modeling.networks.encoder_scaffold import EncoderScaffold
from official.nlp.modeling.networks.funnel_transformer import FunnelTransformerEncoder
...
official/nlp/modeling/networks/bert_encoder.py (+59, -166)

This is the substantive change: BertEncoderV2 is renamed to BertEncoder, dict/list output handling and the legacy keyword arguments are folded into the merged class, and the old wrapper class at the bottom of the file is deleted. Where a line is rewritten, the diff lists the old line directly followed by its replacement.

@@ -12,20 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer-based text encoder network."""
"""Transformer-based BERT encoder network."""
# pylint: disable=g-classes-have-attributes
import collections
from absl import logging
import tensorflow as tf
from official.modeling import activations
from official.nlp.modeling import layers

# TODO(b/202413395): Merge V2 and V1.


@tf.keras.utils.register_keras_serializable(package='Text')
class BertEncoderV2(tf.keras.Model):
class BertEncoder(tf.keras.Model):
"""Bi-directional Transformer-based encoder network.
This network implements a bi-directional Transformer-based encoder as
...
...
@@ -74,6 +71,11 @@ class BertEncoderV2(tf.keras.Model):
norm_first: Whether to normalize inputs to attention and intermediate
dense layers. If set False, output of attention and intermediate dense
layers is normalized.
dict_outputs: Whether to use a dictionary as the model outputs.
return_all_encoder_outputs: Whether to output sequence embedding outputs of
all encoder transformer layers. Note: when the following `dict_outputs`
argument is True, all encoder outputs are always returned in the dict,
keyed by `encoder_outputs`.
"""
  def __init__(
      ...

@@ -93,7 +95,28 @@ class BertEncoderV2(tf.keras.Model):
      embedding_width=None,
      embedding_layer=None,
      norm_first=False,
      dict_outputs=False,
      return_all_encoder_outputs=False,
      **kwargs):
    if 'sequence_length' in kwargs:
      kwargs.pop('sequence_length')
      logging.warning('`sequence_length` is a deprecated argument to '
                      '`BertEncoder`, which has no effect for a while. Please '
                      'remove `sequence_length` argument.')

    # Handles backward compatible kwargs.
    if 'intermediate_size' in kwargs:
      inner_dim = kwargs.pop('intermediate_size')
    if 'activation' in kwargs:
      inner_activation = kwargs.pop('activation')
    if 'dropout_rate' in kwargs:
      output_dropout = kwargs.pop('dropout_rate')
    if 'attention_dropout_rate' in kwargs:
      attention_dropout = kwargs.pop('attention_dropout_rate')

    activation = tf.keras.activations.get(inner_activation)
    initializer = tf.keras.initializers.get(initializer)
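The kwargs handling added here means the merged class still accepts the old V1 argument names and silently maps them onto the new ones. A sketch of two constructions that should be equivalent under that mapping (tf-models-official assumed, sizes arbitrary):

from official.nlp.modeling import networks

# New-style argument names.
encoder_a = networks.BertEncoder(
    vocab_size=100, hidden_size=32, num_layers=2, num_attention_heads=2,
    inner_dim=128, inner_activation='relu',
    output_dropout=0.1, attention_dropout=0.1)

# Old-style names are popped from **kwargs and translated.
encoder_b = networks.BertEncoder(
    vocab_size=100, hidden_size=32, num_layers=2, num_attention_heads=2,
    intermediate_size=128, activation='relu',
    dropout_rate=0.1, attention_dropout_rate=0.1)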
...

@@ -194,14 +217,30 @@ class BertEncoderV2(tf.keras.Model):
        encoder_outputs=encoder_outputs,
    )

    # Once we've created the network using the Functional API, we call
    # super().__init__ as though we were invoking the Functional API Model
    # constructor, resulting in this object having all the properties of a model
    # created using the Functional API. Once super().__init__ is called, we
    # can assign attributes to `self` - note that all `self` assignments are
    # below this line.
    super(BertEncoderV2, self).__init__(
        inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)
    if dict_outputs:
      super().__init__(
          inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)
    else:
      cls_output = outputs['pooled_output']
      if return_all_encoder_outputs:
        encoder_outputs = outputs['encoder_outputs']
        outputs = [encoder_outputs, cls_output]
      else:
        sequence_output = outputs['sequence_output']
        outputs = [sequence_output, cls_output]
      super().__init__(  # pylint: disable=bad-super-call
          inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)

    self._pooler_layer = pooler_layer
    self._transformer_layers = transformer_layers
    self._embedding_norm_layer = embedding_norm_layer
    self._embedding_layer = embedding_layer_inst
    self._position_embedding_layer = position_embedding_layer
    self._type_embedding_layer = type_embedding_layer
    if embedding_projection is not None:
      self._embedding_projection = embedding_projection

    config_dict = {
        'vocab_size': vocab_size,
        ...
@@ -219,23 +258,13 @@ class BertEncoderV2(tf.keras.Model):
        'embedding_width': embedding_width,
        'embedding_layer': embedding_layer,
        'norm_first': norm_first,
        'dict_outputs': dict_outputs,
    }

    # We are storing the config dict as a namedtuple here to ensure checkpoint
    # compatibility with an earlier version of this model which did not track
    # the config dict attribute. TF does not track immutable attrs which
    # do not contain Trackables, so by creating a config namedtuple instead of
    # a dict we avoid tracking it.
    config_cls = collections.namedtuple('Config', config_dict.keys())
    self._config = config_cls(**config_dict)
    self._pooler_layer = pooler_layer
    self._transformer_layers = transformer_layers
    self._embedding_norm_layer = embedding_norm_layer
    self._embedding_layer = embedding_layer_inst
    self._position_embedding_layer = position_embedding_layer
    self._type_embedding_layer = type_embedding_layer
    if embedding_projection is not None:
      self._embedding_projection = embedding_projection

    # pylint: disable=protected-access
    self._setattr_tracking = False
    self._config = config_dict
    self._setattr_tracking = True
    # pylint: enable=protected-access
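Both variants shown above work around the same issue: Keras auto-tracks mutable attributes assigned to a model, and the config must not be tracked if older checkpoints are to stay compatible. The removed code froze the dict into a namedtuple (immutable, so not tracked); the new code briefly turns attribute tracking off while assigning the plain dict. A small sketch of the namedtuple round trip that the old get_config() performed (values are placeholders):

import collections

config_dict = {'vocab_size': 100, 'hidden_size': 32, 'num_layers': 3}

# Freeze the dict into an immutable namedtuple...
Config = collections.namedtuple('Config', config_dict.keys())
frozen = Config(**config_dict)
print(frozen.hidden_size)        # 32

# ...and thaw it back into a plain dict, as the old get_config() did.
print(dict(frozen._asdict()))    # {'vocab_size': 100, 'hidden_size': 32, 'num_layers': 3}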
  def get_embedding_table(self):
    return self._embedding_layer.embeddings
  ...

@@ -244,7 +273,7 @@ class BertEncoderV2(tf.keras.Model):
    return self._embedding_layer

  def get_config(self):
    return dict(self._config._asdict())
    return self._config

  @property
  def transformer_layers(self):
  ...
@@ -268,139 +297,3 @@ class BertEncoderV2(tf.keras.Model):
      logging.warn(warn_string)
    return cls(**config)


# This class is being replaced by BertEncoderV2 and merely
# acts as a wrapper if you need: 1) list outputs instead of dict outputs,
# 2) shared embedding layer.
@tf.keras.utils.register_keras_serializable(package='Text')
class BertEncoder(BertEncoderV2):
  """Bi-directional Transformer-based encoder network.

  This network implements a bi-directional Transformer-based encoder as
  described in "BERT: Pre-training of Deep Bidirectional Transformers for
  Language Understanding" (https://arxiv.org/abs/1810.04805). It includes the
  embedding lookups and transformer layers, but not the masked language model
  or classification task networks.

  The default values for this object are taken from the BERT-Base implementation
  in "BERT: Pre-training of Deep Bidirectional Transformers for Language
  Understanding".

  *Note* that the network is constructed by
  [Keras Functional API](https://keras.io/guides/functional_api/).

  Args:
    vocab_size: The size of the token vocabulary.
    hidden_size: The size of the transformer hidden layers.
    num_layers: The number of transformer layers.
    num_attention_heads: The number of attention heads for each transformer. The
      hidden size must be divisible by the number of attention heads.
    max_sequence_length: The maximum sequence length that this encoder can
      consume. If None, max_sequence_length uses the value from sequence length.
      This determines the variable shape for positional embeddings.
    type_vocab_size: The number of types that the 'type_ids' input can take.
    intermediate_size: The intermediate size for the transformer layers.
    activation: The activation to use for the transformer layers.
    dropout_rate: The dropout rate to use for the transformer layers.
    attention_dropout_rate: The dropout rate to use for the attention layers
      within the transformer layers.
    initializer: The initialzer to use for all weights in this encoder.
    return_all_encoder_outputs: Whether to output sequence embedding outputs of
      all encoder transformer layers. Note: when the following `dict_outputs`
      argument is True, all encoder outputs are always returned in the dict,
      keyed by `encoder_outputs`.
    output_range: The sequence output range, [0, output_range), by slicing the
      target sequence of the last transformer layer. `None` means the entire
      target sequence will attend to the source sequence, which yields the full
      output.
    embedding_width: The width of the word embeddings. If the embedding width is
      not equal to hidden size, embedding parameters will be factorized into two
      matrices in the shape of `(vocab_size, embedding_width)` and
      `(embedding_width, hidden_size)`, where `embedding_width` is usually much
      smaller than `hidden_size`.
    embedding_layer: The word embedding layer. `None` means we will create a new
      embedding layer. Otherwise, we will reuse the given embedding layer. This
      parameter is originally added for ELECTRA model which needs to tie the
      generator embeddings with the discriminator embeddings.
    dict_outputs: Whether to use a dictionary as the model outputs.
    norm_first: Whether to normalize inputs to attention and intermediate
      dense layers. If set False, output of attention and intermediate dense
      layers is normalized.
  """
  def __init__(
      self,
      vocab_size,
      hidden_size=768,
      num_layers=12,
      num_attention_heads=12,
      max_sequence_length=512,
      type_vocab_size=16,
      intermediate_size=3072,
      activation=activations.gelu,
      dropout_rate=0.1,
      attention_dropout_rate=0.1,
      initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
      return_all_encoder_outputs=False,
      output_range=None,
      embedding_width=None,
      embedding_layer=None,
      dict_outputs=False,
      norm_first=False,
      **kwargs):
    # b/164516224
    # Once we've created the network using the Functional API, we call
    # super().__init__ as though we were invoking the Functional API Model
    # constructor, resulting in this object having all the properties of a model
    # created using the Functional API. Once super().__init__ is called, we
    # can assign attributes to `self` - note that all `self` assignments are
    # below this line.
    super(BertEncoder, self).__init__(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        max_sequence_length=max_sequence_length,
        type_vocab_size=type_vocab_size,
        inner_dim=intermediate_size,
        inner_activation=activation,
        output_dropout=dropout_rate,
        attention_dropout=attention_dropout_rate,
        initializer=initializer,
        output_range=output_range,
        embedding_width=embedding_width,
        embedding_layer=embedding_layer,
        norm_first=norm_first)
    if 'sequence_length' in kwargs:
      kwargs.pop('sequence_length')
      logging.warning('`sequence_length` is a deprecated argument to '
                      '`BertEncoder`, which has no effect for a while. Please '
                      'remove `sequence_length` argument.')
    self._embedding_layer_instance = embedding_layer

    # Replace arguments from keras_nlp.encoders.BertEncoder.
    config_dict = self._config._asdict()
    config_dict['activation'] = config_dict.pop('inner_activation')
    config_dict['intermediate_size'] = config_dict.pop('inner_dim')
    config_dict['dropout_rate'] = config_dict.pop('output_dropout')
    config_dict['attention_dropout_rate'] = config_dict.pop('attention_dropout')
    config_dict['dict_outputs'] = dict_outputs
    config_dict['return_all_encoder_outputs'] = return_all_encoder_outputs
    config_cls = collections.namedtuple('Config', config_dict.keys())
    self._config = config_cls(**config_dict)

    if dict_outputs:
      return
    else:
      nested_output = self._nested_outputs
      cls_output = nested_output['pooled_output']
      if return_all_encoder_outputs:
        encoder_outputs = nested_output['encoder_outputs']
        outputs = [encoder_outputs, cls_output]
      else:
        sequence_output = nested_output['sequence_output']
        outputs = [sequence_output, cls_output]
      super(BertEncoderV2, self).__init__(  # pylint: disable=bad-super-call
          inputs=self.inputs, outputs=outputs, **kwargs)
official/nlp/modeling/networks/bert_encoder_test.py (+25, -82)

The V2 test class becomes the single BertEncoderTest, its methods gain a v2_ prefix, the constructors switch from BertEncoderV2 to BertEncoder(..., dict_outputs=True), and a duplicated serialization test is removed. As in the previous file, an old line is shown directly followed by its replacement.

@@ -26,21 +26,22 @@ from official.nlp.modeling.networks import bert_encoder
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class BertEncoderV2Test(keras_parameterized.TestCase):
class BertEncoderTest(keras_parameterized.TestCase):

  def tearDown(self):
    super(BertEncoderV2Test, self).tearDown()
    super(BertEncoderTest, self).tearDown()
    tf.keras.mixed_precision.set_global_policy("float32")

  def test_network_creation(self):
  def test_v2_network_creation(self):
    hidden_size = 32
    sequence_length = 21
    # Create a small BertEncoder for testing.
    test_network = bert_encoder.BertEncoderV2(
    test_network = bert_encoder.BertEncoder(
        vocab_size=100,
        hidden_size=hidden_size,
        num_attention_heads=2,
        num_layers=3)
        num_layers=3,
        dict_outputs=True)
    # Create the inputs (note that the first dimension is implicit).
    word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
...
@@ -62,15 +63,16 @@ class BertEncoderV2Test(keras_parameterized.TestCase):
    self.assertAllEqual(tf.float32, data.dtype)
    self.assertAllEqual(tf.float32, pooled.dtype)

  def test_all_encoder_outputs_network_creation(self):
  def test_v2_all_encoder_outputs_network_creation(self):
    hidden_size = 32
    sequence_length = 21
    # Create a small BertEncoder for testing.
    test_network = bert_encoder.BertEncoderV2(
    test_network = bert_encoder.BertEncoder(
        vocab_size=100,
        hidden_size=hidden_size,
        num_attention_heads=2,
        num_layers=3)
        num_layers=3,
        dict_outputs=True)
    # Create the inputs (note that the first dimension is implicit).
    word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
...
@@ -90,16 +92,17 @@ class BertEncoderV2Test(keras_parameterized.TestCase):
    self.assertAllEqual(tf.float32, all_encoder_outputs[-1].dtype)
    self.assertAllEqual(tf.float32, pooled.dtype)

  def test_network_creation_with_float16_dtype(self):
  def test_v2_network_creation_with_float16_dtype(self):
    hidden_size = 32
    sequence_length = 21
    tf.keras.mixed_precision.set_global_policy("mixed_float16")
    # Create a small BertEncoder for testing.
    test_network = bert_encoder.BertEncoderV2(
    test_network = bert_encoder.BertEncoder(
        vocab_size=100,
        hidden_size=hidden_size,
        num_attention_heads=2,
        num_layers=3)
        num_layers=3,
        dict_outputs=True)
    # Create the inputs (note that the first dimension is implicit).
    word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
...
@@ -122,19 +125,20 @@ class BertEncoderV2Test(keras_parameterized.TestCase):
      ("all_sequence", None, 21),
      ("output_range", 1, 1),
  )
  def test_network_invocation(self, output_range, out_seq_len):
  def test_v2_network_invocation(self, output_range, out_seq_len):
    hidden_size = 32
    sequence_length = 21
    vocab_size = 57
    num_types = 7
    # Create a small BertEncoder for testing.
    test_network = bert_encoder.BertEncoderV2(
    test_network = bert_encoder.BertEncoder(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        num_attention_heads=2,
        num_layers=3,
        type_vocab_size=num_types,
        output_range=output_range)
        output_range=output_range,
        dict_outputs=True)
    # Create the inputs (note that the first dimension is implicit).
    word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
...
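The parameterization above is the whole story of output_range: with None the sequence output keeps the full input length, while output_range=1 slices the last transformer layer so only the first position is computed and returned. A self-contained sketch mirroring the test's sizes (tf-models-official assumed):

import numpy as np
from official.nlp.modeling import networks

def sequence_output_length(output_range):
  encoder = networks.BertEncoder(
      vocab_size=57, hidden_size=32, num_attention_heads=2, num_layers=3,
      type_vocab_size=7, output_range=output_range, dict_outputs=True)
  seq_len = 21
  word_ids = np.random.randint(0, 57, size=(1, seq_len), dtype=np.int32)
  mask = np.ones((1, seq_len), dtype=np.int32)
  type_ids = np.zeros((1, seq_len), dtype=np.int32)
  outputs = encoder([word_ids, mask, type_ids])
  return outputs['sequence_output'].shape[1]

print(sequence_output_length(None))  # 21, the full sequence ("all_sequence" case)
print(sequence_output_length(1))     # 1, only the first position ("output_range" case)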
@@ -159,13 +163,14 @@ class BertEncoderV2Test(keras_parameterized.TestCase):
    # Creates a BertEncoder with max_sequence_length != sequence_length
    max_sequence_length = 128
    test_network = bert_encoder.BertEncoderV2(
    test_network = bert_encoder.BertEncoder(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        max_sequence_length=max_sequence_length,
        num_attention_heads=2,
        num_layers=3,
        type_vocab_size=num_types)
        type_vocab_size=num_types,
        dict_outputs=True)
    dict_outputs = test_network([word_ids, mask, type_ids])
    data = dict_outputs["sequence_output"]
    pooled = dict_outputs["pooled_output"]
...
@@ -174,14 +179,15 @@ class BertEncoderV2Test(keras_parameterized.TestCase):
    self.assertEqual(outputs[0].shape[1], sequence_length)

    # Creates a BertEncoder with embedding_width != hidden_size
    test_network = bert_encoder.BertEncoderV2(
    test_network = bert_encoder.BertEncoder(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        max_sequence_length=max_sequence_length,
        num_attention_heads=2,
        num_layers=3,
        type_vocab_size=num_types,
        embedding_width=16)
        embedding_width=16,
        dict_outputs=True)
    dict_outputs = test_network([word_ids, mask, type_ids])
    data = dict_outputs["sequence_output"]
    pooled = dict_outputs["pooled_output"]
...
@@ -208,37 +214,16 @@ class BertEncoderV2Test(keras_parameterized.TestCase):
        embedding_width=16,
        embedding_layer=None,
        norm_first=False)
    network = bert_encoder.BertEncoderV2(**kwargs)
    expected_config = dict(kwargs)
    expected_config["inner_activation"] = tf.keras.activations.serialize(
        tf.keras.activations.get(expected_config["inner_activation"]))
    expected_config["initializer"] = tf.keras.initializers.serialize(
        tf.keras.initializers.get(expected_config["initializer"]))
    self.assertEqual(network.get_config(), expected_config)

    # Create another network object from the first object's config.
    new_network = bert_encoder.BertEncoderV2.from_config(network.get_config())
    network = bert_encoder.BertEncoder(**kwargs)

    # Validate that the config can be forced to JSON.
    _ = network.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(network.get_config(), new_network.get_config())

    # Tests model saving/loading.
    model_path = self.get_temp_dir() + "/model"
    network.save(model_path)
    _ = tf.keras.models.load_model(model_path)


# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class BertEncoderTest(keras_parameterized.TestCase):

  def tearDown(self):
    super(BertEncoderTest, self).tearDown()
    tf.keras.mixed_precision.set_global_policy("float32")

  def test_network_creation(self):
    hidden_size = 32
    sequence_length = 21
...
@@ -415,48 +400,6 @@ class BertEncoderTest(keras_parameterized.TestCase):
    self.assertEqual(outputs[0].shape[-1], hidden_size)
    self.assertTrue(hasattr(test_network, "_embedding_projection"))

  def test_serialize_deserialize(self):
    # Create a network object that sets all of its config options.
    kwargs = dict(
        vocab_size=100,
        hidden_size=32,
        num_layers=3,
        num_attention_heads=2,
        max_sequence_length=21,
        type_vocab_size=12,
        intermediate_size=1223,
        activation="relu",
        dropout_rate=0.05,
        attention_dropout_rate=0.22,
        initializer="glorot_uniform",
        return_all_encoder_outputs=False,
        output_range=-1,
        embedding_width=16,
        dict_outputs=True,
        embedding_layer=None,
        norm_first=False)
    network = bert_encoder.BertEncoder(**kwargs)
    expected_config = dict(kwargs)
    expected_config["activation"] = tf.keras.activations.serialize(
        tf.keras.activations.get(expected_config["activation"]))
    expected_config["initializer"] = tf.keras.initializers.serialize(
        tf.keras.initializers.get(expected_config["initializer"]))
    self.assertEqual(network.get_config(), expected_config)

    # Create another network object from the first object's config.
    new_network = bert_encoder.BertEncoder.from_config(network.get_config())

    # Validate that the config can be forced to JSON.
    _ = network.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(network.get_config(), new_network.get_config())

    # Tests model saving/loading.
    model_path = self.get_temp_dir() + "/model"
    network.save(model_path)
    _ = tf.keras.models.load_model(model_path)


if __name__ == "__main__":
  tf.test.main()