Commit ba206271 authored by Zhenyu Tan, committed by A. Unique TensorFlower

Alias nlp/modeling BertEncoder with keras_nlp BertEncoder.

PiperOrigin-RevId: 332386511
parent 5cee7220
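For context, a minimal sketch of what this alias means for callers. The import path official.nlp.modeling.networks and the small layer sizes are illustrative assumptions (the file paths are not shown in this view); the constructor arguments, the three int32 inputs, and the list-output contract all come from the code in the diff below.

import numpy as np
from official.nlp import keras_nlp
from official.nlp.modeling import networks  # assumed path of the wrapped class

# After this commit the modeling-layer BertEncoder is a subclass of the
# keras_nlp encoder rather than a separate tf.keras.Model implementation.
encoder = networks.BertEncoder(
    vocab_size=100, hidden_size=32, num_layers=2, num_attention_heads=2,
    intermediate_size=64)
assert isinstance(encoder, keras_nlp.encoders.BertEncoder)

# The wrapper keeps the old list-output contract by default:
# [sequence_output, pooled_output] for [input_word_ids, input_mask, input_type_ids].
word_ids = np.zeros((2, 8), dtype=np.int32)
mask = np.ones((2, 8), dtype=np.int32)
type_ids = np.zeros((2, 8), dtype=np.int32)
sequence_output, pooled_output = encoder([word_ids, mask, type_ids])
print(sequence_output.shape, pooled_output.shape)  # (2, 8, 32) (2, 32)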
@@ -14,4 +14,5 @@
 # ==============================================================================
 """Keras-NLP package definition."""
 # pylint: disable=wildcard-import
+from official.nlp.keras_nlp import encoders
 from official.nlp.keras_nlp import layers
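The hunk above re-exports the encoders module from the keras_nlp package root, next to the existing layers re-export. A quick sanity check, assuming the package is on the Python path:

from official.nlp import keras_nlp

# Both submodules are now reachable from the package root.
print(keras_nlp.encoders.BertEncoder)
print(keras_nlp.layers.TransformerEncoderBlock)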
@@ -21,8 +21,11 @@ from official.modeling import activations
 from official.nlp import keras_nlp
 
 
+# This class is being replaced by keras_nlp.encoders.BertEncoder and merely
+# acts as a wrapper if you need: 1) list outputs instead of dict outputs,
+# 2) shared embedding layer.
 @tf.keras.utils.register_keras_serializable(package='Text')
-class BertEncoder(tf.keras.Model):
+class BertEncoder(keras_nlp.encoders.BertEncoder):
   """Bi-directional Transformer-based encoder network.
 
   This network implements a bi-directional Transformer-based encoder as
@@ -93,145 +96,51 @@ class BertEncoder(tf.keras.Model):
                embedding_layer=None,
                dict_outputs=False,
                **kwargs):
-    activation = tf.keras.activations.get(activation)
-    initializer = tf.keras.initializers.get(initializer)
     self._self_setattr_tracking = False
-    self._config_dict = {
-        'vocab_size': vocab_size,
-        'hidden_size': hidden_size,
-        'num_layers': num_layers,
-        'num_attention_heads': num_attention_heads,
-        'max_sequence_length': max_sequence_length,
-        'type_vocab_size': type_vocab_size,
-        'intermediate_size': intermediate_size,
-        'activation': tf.keras.activations.serialize(activation),
-        'dropout_rate': dropout_rate,
-        'attention_dropout_rate': attention_dropout_rate,
-        'initializer': tf.keras.initializers.serialize(initializer),
-        'return_all_encoder_outputs': return_all_encoder_outputs,
-        'output_range': output_range,
-        'embedding_width': embedding_width,
-        'dict_outputs': dict_outputs,
-    }
-    word_ids = tf.keras.layers.Input(
-        shape=(None,), dtype=tf.int32, name='input_word_ids')
-    mask = tf.keras.layers.Input(
-        shape=(None,), dtype=tf.int32, name='input_mask')
-    type_ids = tf.keras.layers.Input(
-        shape=(None,), dtype=tf.int32, name='input_type_ids')
-    if embedding_width is None:
-      embedding_width = hidden_size
-    if embedding_layer is None:
-      self._embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
-          vocab_size=vocab_size,
-          embedding_width=embedding_width,
-          initializer=initializer,
-          name='word_embeddings')
-    else:
-      self._embedding_layer = embedding_layer
-    word_embeddings = self._embedding_layer(word_ids)
-    # Always uses dynamic slicing for simplicity.
-    self._position_embedding_layer = keras_nlp.layers.PositionEmbedding(
-        initializer=initializer,
-        max_length=max_sequence_length,
-        name='position_embedding')
-    position_embeddings = self._position_embedding_layer(word_embeddings)
-    self._type_embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
-        vocab_size=type_vocab_size,
-        embedding_width=embedding_width,
-        initializer=initializer,
-        use_one_hot=True,
-        name='type_embeddings')
-    type_embeddings = self._type_embedding_layer(type_ids)
-    embeddings = tf.keras.layers.Add()(
-        [word_embeddings, position_embeddings, type_embeddings])
-    self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
-        name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
-    embeddings = self._embedding_norm_layer(embeddings)
-    embeddings = (tf.keras.layers.Dropout(rate=dropout_rate)(embeddings))
-    # We project the 'embedding' output to 'hidden_size' if it is not already
-    # 'hidden_size'.
-    if embedding_width != hidden_size:
-      self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
-          '...x,xy->...y',
-          output_shape=hidden_size,
-          bias_axes='y',
-          kernel_initializer=initializer,
-          name='embedding_projection')
-      embeddings = self._embedding_projection(embeddings)
-    self._transformer_layers = []
-    data = embeddings
-    attention_mask = keras_nlp.layers.SelfAttentionMask()(data, mask)
-    encoder_outputs = []
-    for i in range(num_layers):
-      if i == num_layers - 1 and output_range is not None:
-        transformer_output_range = output_range
-      else:
-        transformer_output_range = None
-      layer = keras_nlp.layers.TransformerEncoderBlock(
-          num_attention_heads=num_attention_heads,
-          inner_dim=intermediate_size,
-          inner_activation=activation,
-          output_dropout=dropout_rate,
-          attention_dropout=attention_dropout_rate,
-          output_range=transformer_output_range,
-          kernel_initializer=initializer,
-          name='transformer/layer_%d' % i)
-      self._transformer_layers.append(layer)
-      data = layer([data, attention_mask])
-      encoder_outputs.append(data)
-
-    first_token_tensor = (
-        tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(
-            encoder_outputs[-1]))
-    self._pooler_layer = tf.keras.layers.Dense(
-        units=hidden_size,
-        activation='tanh',
-        kernel_initializer=initializer,
-        name='pooler_transform')
-    cls_output = self._pooler_layer(first_token_tensor)
-    if dict_outputs:
-      outputs = dict(
-          sequence_output=encoder_outputs[-1],
-          pooled_output=cls_output,
-          encoder_outputs=encoder_outputs,
-      )
-    elif return_all_encoder_outputs:
-      outputs = [encoder_outputs, cls_output]
-    else:
-      outputs = [encoder_outputs[-1], cls_output]
-    super(BertEncoder, self).__init__(
-        inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)
-
-  def get_embedding_table(self):
-    return self._embedding_layer.embeddings
-
-  def get_embedding_layer(self):
-    return self._embedding_layer
-
-  def get_config(self):
-    return self._config_dict
-
-  @property
-  def transformer_layers(self):
-    """List of Transformer layers in the encoder."""
-    return self._transformer_layers
-
-  @property
-  def pooler_layer(self):
-    """The pooler dense layer after the transformer layers."""
-    return self._pooler_layer
-
-  @classmethod
-  def from_config(cls, config, custom_objects=None):
-    return cls(**config)
+    self._embedding_layer_instance = embedding_layer
+
+    super(BertEncoder, self).__init__(
+        vocab_size=vocab_size,
+        hidden_size=hidden_size,
+        num_layers=num_layers,
+        num_attention_heads=num_attention_heads,
+        max_sequence_length=max_sequence_length,
+        type_vocab_size=type_vocab_size,
+        inner_dim=intermediate_size,
+        inner_activation=activation,
+        output_dropout=dropout_rate,
+        attention_dropout=attention_dropout_rate,
+        initializer=initializer,
+        return_all_encoder_outputs=return_all_encoder_outputs,
+        output_range=output_range,
+        embedding_width=embedding_width)
+
+    # Replace arguments from keras_nlp.encoders.BertEncoder.
+    self._config_dict['activation'] = self._config_dict.pop('inner_activation')
+    self._config_dict['intermediate_size'] = self._config_dict.pop('inner_dim')
+    self._config_dict['dropout_rate'] = self._config_dict.pop('output_dropout')
+    self._config_dict['attention_dropout_rate'] = self._config_dict.pop(
+        'attention_dropout')
+    self._config_dict['dict_outputs'] = dict_outputs
+
+    if dict_outputs:
+      return
+    else:
+      nested_output = self._nested_outputs
+      cls_output = nested_output['pooled_output']
+      if return_all_encoder_outputs:
+        encoder_outputs = nested_output['encoder_outputs']
+        outputs = [encoder_outputs, cls_output]
+      else:
+        sequence_output = nested_output['sequence_output']
+        outputs = [sequence_output, cls_output]
+      super(keras_nlp.encoders.BertEncoder, self).__init__(
+          inputs=self.inputs, outputs=outputs, **kwargs)
+
+  # Override method for shared embedding use case.
+  def _build_embedding_layer(self):
+    if self._embedding_layer_instance is None:
+      return super(BertEncoder, self)._build_embedding_layer()
+    else:
+      return self._embedding_layer_instance
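The new comment at the top of the class names the two reasons the wrapper is kept. A sketch of both, with the official.nlp.modeling.networks import path and the small sizes assumed for illustration; the dict_outputs early return, the output keys, and the _build_embedding_layer override come from the diff above.

from official.nlp import keras_nlp
from official.nlp.modeling import networks  # assumed path of the wrapped class

# 1) dict outputs: with dict_outputs=True the constructor returns early, so
#    the model keeps the parent's dict outputs with 'sequence_output',
#    'pooled_output' and 'encoder_outputs'.
dict_encoder = networks.BertEncoder(
    vocab_size=100, hidden_size=32, num_layers=2, num_attention_heads=2,
    intermediate_size=64, dict_outputs=True)

# 2) shared embedding layer: the overridden _build_embedding_layer() returns
#    the layer passed in instead of creating a new word embedding, so another
#    network can reuse the same table.
shared_embedding = keras_nlp.layers.OnDeviceEmbedding(
    vocab_size=100, embedding_width=32, name='word_embeddings')
tied_encoder = networks.BertEncoder(
    vocab_size=100, hidden_size=32, num_layers=2, num_attention_heads=2,
    intermediate_size=64, embedding_layer=shared_embedding)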