Commit a0b4cdb6 authored by A. Unique TensorFlower

Store a handle to self._embedding_norm_layer for future weight copying.

PiperOrigin-RevId: 328364661
parent 42b49ff1
@@ -153,12 +153,10 @@ class TransformerEncoder(tf.keras.Model):
     embeddings = tf.keras.layers.Add()(
         [word_embeddings, position_embeddings, type_embeddings])
-    embeddings = (
-        tf.keras.layers.LayerNormalization(
-            name='embeddings/layer_norm',
-            axis=-1,
-            epsilon=1e-12,
-            dtype=tf.float32)(embeddings))
+    self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
+        name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
+    embeddings = self._embedding_norm_layer(embeddings)
     embeddings = (tf.keras.layers.Dropout(rate=dropout_rate)(embeddings))
     # We project the 'embedding' output to 'hidden_size' if it is not already
...
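The change retains the embedding LayerNormalization layer as an attribute instead of applying it anonymously, so its variables can be reached later for weight copying. A minimal sketch of that pattern, assuming two independently built layers (the names source_norm and target_norm are illustrative and not from the model code):

    import tensorflow as tf

    # Hypothetical example: two layer-norm layers configured like the one in
    # the diff above. Retaining the layer object (as self._embedding_norm_layer
    # now does) is what makes the later weight copy possible.
    source_norm = tf.keras.layers.LayerNormalization(
        name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
    target_norm = tf.keras.layers.LayerNormalization(
        name='embeddings/layer_norm_copy', axis=-1, epsilon=1e-12,
        dtype=tf.float32)

    # Call each layer once so its gamma/beta variables are created.
    dummy = tf.zeros([1, 4, 16])
    source_norm(dummy)
    target_norm(dummy)

    # Copy weights from one layer to the other via the retained handles.
    target_norm.set_weights(source_norm.get_weights())

Without the stored handle, the normalization layer existed only inside the functional-model graph, and its weights could not be addressed directly by name from the encoder object.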