Commit d8d53ba8 authored by Scott Zhu, committed by A. Unique TensorFlower

Prepare for the upcoming Keras initializer change.

PiperOrigin-RevId: 446878547
parent 2600f792
@@ -17,6 +17,8 @@ import math
 import tensorflow as tf
+from official.modeling import tf_utils
 
 class Attention(tf.keras.layers.Layer):
   """Multi-headed attention layer."""
@@ -53,19 +55,19 @@ class Attention(tf.keras.layers.Layer):
     self.query_dense_layer = tf.keras.layers.experimental.EinsumDense(
         "BTE,ENH->BTNH",
         output_shape=(None, self.num_heads, size_per_head),
-        kernel_initializer=attention_initializer,
+        kernel_initializer=tf_utils.clone_initializer(attention_initializer),
         bias_axes=None,
         name="query")
     self.key_dense_layer = tf.keras.layers.experimental.EinsumDense(
         "BTE,ENH->BTNH",
         output_shape=(None, self.num_heads, size_per_head),
-        kernel_initializer=attention_initializer,
+        kernel_initializer=tf_utils.clone_initializer(attention_initializer),
         bias_axes=None,
         name="key")
     self.value_dense_layer = tf.keras.layers.experimental.EinsumDense(
         "BTE,ENH->BTNH",
         output_shape=(None, self.num_heads, size_per_head),
-        kernel_initializer=attention_initializer,
+        kernel_initializer=tf_utils.clone_initializer(attention_initializer),
         bias_axes=None,
         name="value")
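Context for the change, as a hedged sketch rather than the exact library code: when Keras initializers carry per-instance random state, sharing one seeded initializer object across several layers can yield correlated weights or reuse errors, so each EinsumDense layer is given its own copy. The body of tf_utils.clone_initializer is not shown in this diff; the snippet below assumes it rebuilds the initializer from its config, which is one common way to obtain an independent copy.

    import tensorflow as tf

    def clone_initializer(initializer):
      """Sketch of an initializer-cloning helper (assumed behavior).

      Rebuilding from the config gives each layer an independent instance,
      so per-instance random state is not shared across layers.
      """
      if isinstance(initializer, tf.keras.initializers.Initializer):
        return initializer.__class__.from_config(initializer.get_config())
      # Strings and plain callables carry no instance state; share them as-is.
      return initializer

    # Usage: every layer receives its own copy of the configured initializer.
    attention_initializer = tf.keras.initializers.TruncatedNormal(stddev=0.02)
    query = tf.keras.layers.Dense(
        64, kernel_initializer=clone_initializer(attention_initializer))
    key = tf.keras.layers.Dense(
        64, kernel_initializer=clone_initializer(attention_initializer))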