Commit 5e641c43 authored by A. Unique TensorFlower
Browse files

Add an interface in nlp.modeling.networks.encoder_scaffold.EncoderScaffold and...

Add an interface in nlp.modeling.networks.encoder_scaffold.EncoderScaffold and nlp.modeling.networks.transformer_encoder.TransformerEncoder to get a reference to the pooler Dense layer.

PiperOrigin-RevId: 310675150
parent 98dd890b
...@@ -191,12 +191,12 @@ class EncoderScaffold(network.Network): ...@@ -191,12 +191,12 @@ class EncoderScaffold(network.Network):
first_token_tensor = ( first_token_tensor = (
tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))( tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(
layer_output_data[-1])) layer_output_data[-1]))
cls_output = tf.keras.layers.Dense( self._pooler_layer = tf.keras.layers.Dense(
units=pooled_output_dim, units=pooled_output_dim,
activation='tanh', activation='tanh',
kernel_initializer=pooler_layer_initializer, kernel_initializer=pooler_layer_initializer,
name='cls_transform')( name='cls_transform')
first_token_tensor) cls_output = self._pooler_layer(first_token_tensor)
if return_all_layer_outputs: if return_all_layer_outputs:
outputs = [layer_output_data, cls_output] outputs = [layer_output_data, cls_output]
...@@ -263,3 +263,8 @@ class EncoderScaffold(network.Network): ...@@ -263,3 +263,8 @@ class EncoderScaffold(network.Network):
def hidden_layers(self): def hidden_layers(self):
"""List of hidden layers in the encoder.""" """List of hidden layers in the encoder."""
return self._hidden_layers return self._hidden_layers
@property
def pooler_layer(self):
"""The pooler dense layer after the transformer layers."""
return self._pooler_layer
...@@ -116,6 +116,7 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase): ...@@ -116,6 +116,7 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
data = output_data data = output_data
self.assertIsInstance(test_network.hidden_layers, list) self.assertIsInstance(test_network.hidden_layers, list)
self.assertLen(test_network.hidden_layers, num_hidden_instances) self.assertLen(test_network.hidden_layers, num_hidden_instances)
self.assertIsInstance(test_network.pooler_layer, tf.keras.layers.Dense)
expected_data_shape = [None, sequence_length, hidden_size] expected_data_shape = [None, sequence_length, hidden_size]
expected_pooled_shape = [None, hidden_size] expected_pooled_shape = [None, hidden_size]
......
...@@ -172,12 +172,12 @@ class TransformerEncoder(network.Network): ...@@ -172,12 +172,12 @@ class TransformerEncoder(network.Network):
first_token_tensor = ( first_token_tensor = (
tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))( tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(
encoder_outputs[-1])) encoder_outputs[-1]))
cls_output = tf.keras.layers.Dense( self._pooler_layer = tf.keras.layers.Dense(
units=hidden_size, units=hidden_size,
activation='tanh', activation='tanh',
kernel_initializer=initializer, kernel_initializer=initializer,
name='pooler_transform')( name='pooler_transform')
first_token_tensor) cls_output = self._pooler_layer(first_token_tensor)
if return_all_encoder_outputs: if return_all_encoder_outputs:
outputs = [encoder_outputs, cls_output] outputs = [encoder_outputs, cls_output]
...@@ -198,6 +198,11 @@ class TransformerEncoder(network.Network): ...@@ -198,6 +198,11 @@ class TransformerEncoder(network.Network):
"""List of Transformer layers in the encoder.""" """List of Transformer layers in the encoder."""
return self._transformer_layers return self._transformer_layers
@property
def pooler_layer(self):
"""The pooler dense layer after the transformer layers."""
return self._pooler_layer
@classmethod @classmethod
def from_config(cls, config, custom_objects=None): def from_config(cls, config, custom_objects=None):
return cls(**config) return cls(**config)
...@@ -51,6 +51,10 @@ class TransformerEncoderTest(keras_parameterized.TestCase): ...@@ -51,6 +51,10 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32) type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
data, pooled = test_network([word_ids, mask, type_ids]) data, pooled = test_network([word_ids, mask, type_ids])
self.assertIsInstance(test_network.transformer_layers, list)
self.assertLen(test_network.transformer_layers, 3)
self.assertIsInstance(test_network.pooler_layer, tf.keras.layers.Dense)
expected_data_shape = [None, sequence_length, hidden_size] expected_data_shape = [None, sequence_length, hidden_size]
expected_pooled_shape = [None, hidden_size] expected_pooled_shape = [None, hidden_size]
self.assertAllEqual(expected_data_shape, data.shape.as_list()) self.assertAllEqual(expected_data_shape, data.shape.as_list())
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment