Commit 5e641c43 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Add an interface in nlp.modeling.networks.encoder_scaffold.EncoderScaffold and...

Add an interface in nlp.modeling.networks.encoder_scaffold.EncoderScaffold and nlp.modeling.networks.transformer_encoder.TransformerEncoder, to get the pooler Dense layer reference.

PiperOrigin-RevId: 310675150
parent 98dd890b
......@@ -191,12 +191,12 @@ class EncoderScaffold(network.Network):
first_token_tensor = (
tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(
layer_output_data[-1]))
cls_output = tf.keras.layers.Dense(
self._pooler_layer = tf.keras.layers.Dense(
units=pooled_output_dim,
activation='tanh',
kernel_initializer=pooler_layer_initializer,
name='cls_transform')(
first_token_tensor)
name='cls_transform')
cls_output = self._pooler_layer(first_token_tensor)
if return_all_layer_outputs:
outputs = [layer_output_data, cls_output]
......@@ -263,3 +263,8 @@ class EncoderScaffold(network.Network):
def hidden_layers(self):
"""List of hidden layers in the encoder."""
return self._hidden_layers
@property
def pooler_layer(self):
"""The pooler dense layer after the transformer layers."""
return self._pooler_layer
......@@ -116,6 +116,7 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
data = output_data
self.assertIsInstance(test_network.hidden_layers, list)
self.assertLen(test_network.hidden_layers, num_hidden_instances)
self.assertIsInstance(test_network.pooler_layer, tf.keras.layers.Dense)
expected_data_shape = [None, sequence_length, hidden_size]
expected_pooled_shape = [None, hidden_size]
......
......@@ -172,12 +172,12 @@ class TransformerEncoder(network.Network):
first_token_tensor = (
tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(
encoder_outputs[-1]))
cls_output = tf.keras.layers.Dense(
self._pooler_layer = tf.keras.layers.Dense(
units=hidden_size,
activation='tanh',
kernel_initializer=initializer,
name='pooler_transform')(
first_token_tensor)
name='pooler_transform')
cls_output = self._pooler_layer(first_token_tensor)
if return_all_encoder_outputs:
outputs = [encoder_outputs, cls_output]
......@@ -198,6 +198,11 @@ class TransformerEncoder(network.Network):
"""List of Transformer layers in the encoder."""
return self._transformer_layers
@property
def pooler_layer(self):
"""The pooler dense layer after the transformer layers."""
return self._pooler_layer
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
......@@ -51,6 +51,10 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
data, pooled = test_network([word_ids, mask, type_ids])
self.assertIsInstance(test_network.transformer_layers, list)
self.assertLen(test_network.transformer_layers, 3)
self.assertIsInstance(test_network.pooler_layer, tf.keras.layers.Dense)
expected_data_shape = [None, sequence_length, hidden_size]
expected_pooled_shape = [None, hidden_size]
self.assertAllEqual(expected_data_shape, data.shape.as_list())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment