"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "f322c788208c66ef5c38a4bc8b6a909f034c0889"
Commit 58edfb5c authored by Le Hou, committed by A. Unique TensorFlower

Support adding a layer normalization layer before the final pooling layer.

PiperOrigin-RevId: 343092077
parent 7bfdae1a
@@ -86,6 +86,9 @@ class EncoderScaffold(tf.keras.Model):
         "dropout_rate": The overall dropout rate for the transformer layers.
         "attention_dropout_rate": The dropout rate for the attention layers.
         "kernel_initializer": The initializer for the transformer layers.
+    layer_norm_before_pooling: Whether to add a layer norm before the pooling
+      layer. You probably want to turn this on if you set norm_first=True in
+      transformer layers.
     return_all_layer_outputs: Whether to output sequence embedding outputs of
       all encoder transformer layers.
     dict_outputs: Whether to use a dictionary as the model outputs.
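For context, here is a minimal, hypothetical construction sketch (not part of this commit) that enables the new flag together with norm_first=True, as the docstring recommends. The config dictionaries are illustrative, modeled on the accompanying unit tests; exact keys, and whether the default transformer layer accepts norm_first, may vary across Model Garden versions.

    import tensorflow as tf
    from official.nlp.modeling import networks

    hidden_cfg = {
        "num_attention_heads": 2,
        "intermediate_size": 3072,
        "intermediate_activation": "relu",
        "dropout_rate": 0.1,
        "attention_dropout_rate": 0.1,
        "kernel_initializer": tf.keras.initializers.TruncatedNormal(stddev=0.02),
        "norm_first": True,  # pre-layer-norm blocks (assumed option name)
    }
    embedding_cfg = {
        "vocab_size": 100,
        "type_vocab_size": 2,
        "hidden_size": 32,
        "seq_length": 16,
        "max_seq_length": 16,
        "initializer": tf.keras.initializers.TruncatedNormal(stddev=0.02),
        "dropout_rate": 0.1,
    }
    encoder = networks.EncoderScaffold(
        num_hidden_instances=3,
        pooled_output_dim=32,
        hidden_cfg=hidden_cfg,
        embedding_cfg=embedding_cfg,
        layer_norm_before_pooling=True)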
@@ -101,6 +104,7 @@ class EncoderScaffold(tf.keras.Model):
                num_hidden_instances=1,
                hidden_cls=layers.Transformer,
                hidden_cfg=None,
+               layer_norm_before_pooling=False,
                return_all_layer_outputs=False,
                dict_outputs=False,
                **kwargs):
@@ -180,6 +184,14 @@ class EncoderScaffold(tf.keras.Model):
       layer_output_data.append(data)
       hidden_layers.append(layer)

+    if layer_norm_before_pooling:
+      # Normalize the final output.
+      output_layer_norm = tf.keras.layers.LayerNormalization(
+          name='final_layer_norm',
+          axis=-1,
+          epsilon=1e-12)
+      layer_output_data[-1] = output_layer_norm(layer_output_data[-1])
+
     last_layer_output = layer_output_data[-1]
     # Applying a tf.slice op (through subscript notation) to a Keras tensor
     # like this will create a SliceOpLambda layer. This is better than a Lambda
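To make the new data flow concrete, here is a standalone, hypothetical Keras sketch (illustrative shapes and layer names, not the commit's code): the final transformer output is normalized before the first-token slice feeds the dense pooler.

    import tensorflow as tf

    # Assumed shapes: 128-token sequences with hidden size 768.
    sequence_output = tf.keras.Input(shape=(128, 768), dtype=tf.float32)

    # New step: layer norm over the hidden axis, with the diff's settings.
    normed = tf.keras.layers.LayerNormalization(
        name='final_layer_norm', axis=-1, epsilon=1e-12)(sequence_output)

    # Existing steps: subscript slicing creates a SliceOpLambda layer (see the
    # comment above), then a dense tanh pooler transforms the first token.
    first_token = normed[:, 0, :]
    pooled_output = tf.keras.layers.Dense(
        units=768, activation='tanh', name='pooler_transform')(first_token)

    model = tf.keras.Model(inputs=sequence_output, outputs=pooled_output)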
@@ -221,6 +233,7 @@ class EncoderScaffold(tf.keras.Model):
     self._embedding_cls = embedding_cls
     self._embedding_cfg = embedding_cfg
     self._embedding_data = embedding_data
+    self._layer_norm_before_pooling = layer_norm_before_pooling
     self._return_all_layer_outputs = return_all_layer_outputs
     self._dict_outputs = dict_outputs
     self._kwargs = kwargs
@@ -232,6 +245,8 @@ class EncoderScaffold(tf.keras.Model):
     self._embedding_norm_layer = embedding_norm_layer
     self._embedding_network = embedding_network
     self._hidden_layers = hidden_layers
+    if self._layer_norm_before_pooling:
+      self._output_layer_norm = output_layer_norm
     self._pooler_layer = pooler_layer
     logging.info('EncoderScaffold configs: %s', self.get_config())
......
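Since the flag and the layer are stored on the instance, and __init__ ends by logging self.get_config() (above), the setting should round-trip through the model config. A hedged check, reusing the encoder built earlier; the config key name is an assumption:

    config = encoder.get_config()
    # Assumed key; inspect the logged config to confirm in your version.
    print(config.get("layer_norm_before_pooling"))  # expected: True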
@@ -97,6 +97,7 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
         hidden_cls=ValidatedTransformerLayer,
         hidden_cfg=hidden_cfg,
         embedding_cfg=embedding_cfg,
+        layer_norm_before_pooling=True,
         return_all_layer_outputs=return_all_layer_outputs)
     # Create the inputs (note that the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
@@ -128,6 +129,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
     self.assertNotEmpty(call_list)
     self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")
+    self.assertTrue(hasattr(test_network, "_output_layer_norm"))
+
   def test_network_creation_with_float16_dtype(self):
     tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
     hidden_size = 32
......