Commit 58edfb5c authored by Le Hou's avatar Le Hou Committed by A. Unique TensorFlower
Browse files

Support adding a layer normalization layer before the final pooling layer.

PiperOrigin-RevId: 343092077
parent 7bfdae1a
......@@ -86,6 +86,9 @@ class EncoderScaffold(tf.keras.Model):
"dropout_rate": The overall dropout rate for the transformer layers.
"attention_dropout_rate": The dropout rate for the attention layers.
"kernel_initializer": The initializer for the transformer layers.
layer_norm_before_pooling: Whether to add a layer norm before the pooling
layer. You will likely want to enable this if you set norm_first=True in
the transformer layers.
return_all_layer_outputs: Whether to output sequence embedding outputs of
all encoder transformer layers.
dict_outputs: Whether to use a dictionary as the model outputs.
......@@ -101,6 +104,7 @@ class EncoderScaffold(tf.keras.Model):
num_hidden_instances=1,
hidden_cls=layers.Transformer,
hidden_cfg=None,
layer_norm_before_pooling=False,
return_all_layer_outputs=False,
dict_outputs=False,
**kwargs):
......@@ -180,6 +184,14 @@ class EncoderScaffold(tf.keras.Model):
layer_output_data.append(data)
hidden_layers.append(layer)
if layer_norm_before_pooling:
# Normalize the final output.
output_layer_norm = tf.keras.layers.LayerNormalization(
name='final_layer_norm',
axis=-1,
epsilon=1e-12)
layer_output_data[-1] = output_layer_norm(layer_output_data[-1])
last_layer_output = layer_output_data[-1]
# Applying a tf.slice op (through subscript notation) to a Keras tensor
# like this will create a SliceOpLambda layer. This is better than a Lambda
......@@ -221,6 +233,7 @@ class EncoderScaffold(tf.keras.Model):
self._embedding_cls = embedding_cls
self._embedding_cfg = embedding_cfg
self._embedding_data = embedding_data
self._layer_norm_before_pooling = layer_norm_before_pooling
self._return_all_layer_outputs = return_all_layer_outputs
self._dict_outputs = dict_outputs
self._kwargs = kwargs
......@@ -232,6 +245,8 @@ class EncoderScaffold(tf.keras.Model):
self._embedding_norm_layer = embedding_norm_layer
self._embedding_network = embedding_network
self._hidden_layers = hidden_layers
if self._layer_norm_before_pooling:
self._output_layer_norm = output_layer_norm
self._pooler_layer = pooler_layer
logging.info('EncoderScaffold configs: %s', self.get_config())
......
......@@ -97,6 +97,7 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
hidden_cls=ValidatedTransformerLayer,
hidden_cfg=hidden_cfg,
embedding_cfg=embedding_cfg,
layer_norm_before_pooling=True,
return_all_layer_outputs=return_all_layer_outputs)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
......@@ -128,6 +129,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
self.assertNotEmpty(call_list)
self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
self.assertTrue(hasattr(test_network, "_output_layer_norm"))
def test_network_creation_with_float16_dtype(self):
tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
hidden_size = 32
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment