Commit 56a2f704 authored by A. Unique TensorFlower
Browse files

Add interfaces in nlp.modeling.networks.encoder_scaffold.EncoderScaffold,...

Add interfaces in nlp.modeling.networks.encoder_scaffold.EncoderScaffold, allowing it to output all hidden layer references, and all intermediate output data references.

PiperOrigin-RevId: 310509202
parent 3fb1e20f
...@@ -86,6 +86,8 @@ class EncoderScaffold(network.Network): ...@@ -86,6 +86,8 @@ class EncoderScaffold(network.Network):
"dropout_rate": The overall dropout rate for the transformer layers. "dropout_rate": The overall dropout rate for the transformer layers.
"attention_dropout_rate": The dropout rate for the attention layers. "attention_dropout_rate": The dropout rate for the attention layers.
"kernel_initializer": The initializer for the transformer layers. "kernel_initializer": The initializer for the transformer layers.
return_all_layer_outputs: Whether to output sequence embedding outputs of
all encoder transformer layers.
""" """
def __init__( def __init__(
...@@ -99,6 +101,7 @@ class EncoderScaffold(network.Network): ...@@ -99,6 +101,7 @@ class EncoderScaffold(network.Network):
num_hidden_instances=1, num_hidden_instances=1,
hidden_cls=layers.Transformer, hidden_cls=layers.Transformer,
hidden_cfg=None, hidden_cfg=None,
return_all_layer_outputs=False,
**kwargs): **kwargs):
self._self_setattr_tracking = False self._self_setattr_tracking = False
self._hidden_cls = hidden_cls self._hidden_cls = hidden_cls
...@@ -109,6 +112,7 @@ class EncoderScaffold(network.Network): ...@@ -109,6 +112,7 @@ class EncoderScaffold(network.Network):
self._embedding_cls = embedding_cls self._embedding_cls = embedding_cls
self._embedding_cfg = embedding_cfg self._embedding_cfg = embedding_cfg
self._embedding_data = embedding_data self._embedding_data = embedding_data
self._return_all_layer_outputs = return_all_layer_outputs
self._kwargs = kwargs self._kwargs = kwargs
if embedding_cls: if embedding_cls:
...@@ -173,17 +177,20 @@ class EncoderScaffold(network.Network): ...@@ -173,17 +177,20 @@ class EncoderScaffold(network.Network):
attention_mask = layers.SelfAttentionMask()([embeddings, mask]) attention_mask = layers.SelfAttentionMask()([embeddings, mask])
data = embeddings data = embeddings
layer_output_data = []
self._hidden_layers = []
for _ in range(num_hidden_instances): for _ in range(num_hidden_instances):
if inspect.isclass(hidden_cls): if inspect.isclass(hidden_cls):
layer = self._hidden_cls( layer = hidden_cls(**hidden_cfg) if hidden_cfg else hidden_cls()
**hidden_cfg) if hidden_cfg else self._hidden_cls()
else: else:
layer = self._hidden_cls layer = hidden_cls
data = layer([data, attention_mask]) data = layer([data, attention_mask])
layer_output_data.append(data)
self._hidden_layers.append(layer)
first_token_tensor = ( first_token_tensor = (
tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(data) tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(
) layer_output_data[-1]))
cls_output = tf.keras.layers.Dense( cls_output = tf.keras.layers.Dense(
units=pooled_output_dim, units=pooled_output_dim,
activation='tanh', activation='tanh',
...@@ -191,8 +198,13 @@ class EncoderScaffold(network.Network): ...@@ -191,8 +198,13 @@ class EncoderScaffold(network.Network):
name='cls_transform')( name='cls_transform')(
first_token_tensor) first_token_tensor)
if return_all_layer_outputs:
outputs = [layer_output_data, cls_output]
else:
outputs = [layer_output_data[-1], cls_output]
super(EncoderScaffold, self).__init__( super(EncoderScaffold, self).__init__(
inputs=inputs, outputs=[data, cls_output], **kwargs) inputs=inputs, outputs=outputs, **kwargs)
def get_config(self): def get_config(self):
config_dict = { config_dict = {
...@@ -208,6 +220,8 @@ class EncoderScaffold(network.Network): ...@@ -208,6 +220,8 @@ class EncoderScaffold(network.Network):
self._embedding_cfg, self._embedding_cfg,
'hidden_cfg': 'hidden_cfg':
self._hidden_cfg, self._hidden_cfg,
'return_all_layer_outputs':
self._return_all_layer_outputs,
} }
if inspect.isclass(self._hidden_cls): if inspect.isclass(self._hidden_cls):
config_dict['hidden_cls_string'] = tf.keras.utils.get_registered_name( config_dict['hidden_cls_string'] = tf.keras.utils.get_registered_name(
...@@ -244,3 +258,8 @@ class EncoderScaffold(network.Network): ...@@ -244,3 +258,8 @@ class EncoderScaffold(network.Network):
'serialization is not yet supported.') % self.name) 'serialization is not yet supported.') % self.name)
else: else:
return self._embedding_data return self._embedding_data
@property
def hidden_layers(self):
"""List of hidden layers in the encoder."""
return self._hidden_layers
...@@ -18,6 +18,7 @@ from __future__ import absolute_import ...@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
...@@ -57,7 +58,10 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase): ...@@ -57,7 +58,10 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
super(EncoderScaffoldLayerClassTest, self).tearDown() super(EncoderScaffoldLayerClassTest, self).tearDown()
tf.keras.mixed_precision.experimental.set_policy("float32") tf.keras.mixed_precision.experimental.set_policy("float32")
def test_network_creation(self): @parameterized.named_parameters(
dict(testcase_name="only_final_output", return_all_layer_outputs=False),
dict(testcase_name="all_layer_outputs", return_all_layer_outputs=True))
def test_network_creation(self, return_all_layer_outputs):
hidden_size = 32 hidden_size = 32
sequence_length = 21 sequence_length = 21
num_hidden_instances = 3 num_hidden_instances = 3
...@@ -96,12 +100,22 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase): ...@@ -96,12 +100,22 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
stddev=0.02), stddev=0.02),
hidden_cls=ValidatedTransformerLayer, hidden_cls=ValidatedTransformerLayer,
hidden_cfg=hidden_cfg, hidden_cfg=hidden_cfg,
embedding_cfg=embedding_cfg) embedding_cfg=embedding_cfg,
return_all_layer_outputs=return_all_layer_outputs)
# Create the inputs (note that the first dimension is implicit). # Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32) word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32) mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32) type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
data, pooled = test_network([word_ids, mask, type_ids]) output_data, pooled = test_network([word_ids, mask, type_ids])
if return_all_layer_outputs:
self.assertIsInstance(output_data, list)
self.assertLen(output_data, num_hidden_instances)
data = output_data[-1]
else:
data = output_data
self.assertIsInstance(test_network.hidden_layers, list)
self.assertLen(test_network.hidden_layers, num_hidden_instances)
expected_data_shape = [None, sequence_length, hidden_size] expected_data_shape = [None, sequence_length, hidden_size]
expected_pooled_shape = [None, hidden_size] expected_pooled_shape = [None, hidden_size]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment