Commit a0b4cdb6 authored by A. Unique TensorFlower

Get the handle of self._embedding_norm_layer for future weight copying.

PiperOrigin-RevId: 328364661
parent 42b49ff1
```diff
@@ -153,12 +153,10 @@ class TransformerEncoder(tf.keras.Model):
     embeddings = tf.keras.layers.Add()(
         [word_embeddings, position_embeddings, type_embeddings])
-    embeddings = (
-        tf.keras.layers.LayerNormalization(
-            name='embeddings/layer_norm',
-            axis=-1,
-            epsilon=1e-12,
-            dtype=tf.float32)(embeddings))
+    self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
+        name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
+
+    embeddings = self._embedding_norm_layer(embeddings)
     embeddings = (tf.keras.layers.Dropout(rate=dropout_rate)(embeddings))
     # We project the 'embedding' output to 'hidden_size' if it is not already
```
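For context, here is a minimal sketch (not part of this commit) of the kind of weight copying that keeping a direct handle to the layer enables. The `LayerNormalization` constructor arguments mirror the diff above; the standalone `source_norm`/`target_norm` layers and the dummy input shape are hypothetical stand-ins for two encoder instances:

```python
import tensorflow as tf

# Hypothetical source/target layers standing in for the embedding norm
# layers of two TransformerEncoder instances.
source_norm = tf.keras.layers.LayerNormalization(
    name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
target_norm = tf.keras.layers.LayerNormalization(
    name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)

# Build both layers so their gamma/beta variables are created.
dummy = tf.zeros([1, 4, 16])  # [batch, seq_len, hidden] chosen arbitrarily
source_norm(dummy)
target_norm(dummy)

# With a stored handle, copying the normalization weights is a one-liner.
target_norm.set_weights(source_norm.get_weights())
```

Without the stored attribute, the same copy would require looking the layer up by name (e.g. `model.get_layer('embeddings/layer_norm')`), which is more brittle if layer names change.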