Commit a1a7aaca authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Get the handle of self._embedding_norm_layer for future weight copying.

PiperOrigin-RevId: 328389991
parent a0b4cdb6
......@@ -163,12 +163,14 @@ class EncoderScaffold(tf.keras.Model):
embeddings = tf.keras.layers.Add()(
[word_embeddings, position_embeddings, type_embeddings])
embeddings = (
tf.keras.layers.LayerNormalization(
self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
name='embeddings/layer_norm',
axis=-1,
epsilon=1e-12,
dtype=tf.float32)(embeddings))
dtype=tf.float32)
embeddings = self._embedding_norm_layer(embeddings)
embeddings = (
tf.keras.layers.Dropout(
rate=embedding_cfg['dropout_rate'])(embeddings))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment