Commit a1a7aaca authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Get the handle of self._embedding_norm_layer for future weight copying.

PiperOrigin-RevId: 328389991
parent a0b4cdb6
...@@ -163,12 +163,14 @@ class EncoderScaffold(tf.keras.Model): ...@@ -163,12 +163,14 @@ class EncoderScaffold(tf.keras.Model):
      embeddings = tf.keras.layers.Add()(
          [word_embeddings, position_embeddings, type_embeddings])
      self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
          name='embeddings/layer_norm',
          axis=-1,
          epsilon=1e-12,
          dtype=tf.float32)
      embeddings = self._embedding_norm_layer(embeddings)
      embeddings = (
          tf.keras.layers.Dropout(
              rate=embedding_cfg['dropout_rate'])(embeddings))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment