"vscode:/vscode.git/clone" did not exist on "0909bb0d2f87e3d6a73a8e0dc0e38f55ce44a4d4"
Commit d8d53ba8 authored by Scott Zhu's avatar Scott Zhu Committed by A. Unique TensorFlower
Browse files

Prepare for upcoming keras initializer change.

PiperOrigin-RevId: 446878547
parent 2600f792
......@@ -17,6 +17,8 @@ import math
import tensorflow as tf
from official.modeling import tf_utils
class Attention(tf.keras.layers.Layer):
"""Multi-headed attention layer."""
......@@ -53,19 +55,19 @@ class Attention(tf.keras.layers.Layer):
self.query_dense_layer = tf.keras.layers.experimental.EinsumDense(
"BTE,ENH->BTNH",
output_shape=(None, self.num_heads, size_per_head),
kernel_initializer=attention_initializer,
kernel_initializer=tf_utils.clone_initializer(attention_initializer),
bias_axes=None,
name="query")
self.key_dense_layer = tf.keras.layers.experimental.EinsumDense(
"BTE,ENH->BTNH",
output_shape=(None, self.num_heads, size_per_head),
kernel_initializer=attention_initializer,
kernel_initializer=tf_utils.clone_initializer(attention_initializer),
bias_axes=None,
name="key")
self.value_dense_layer = tf.keras.layers.experimental.EinsumDense(
"BTE,ENH->BTNH",
output_shape=(None, self.num_heads, size_per_head),
kernel_initializer=attention_initializer,
kernel_initializer=tf_utils.clone_initializer(attention_initializer),
bias_axes=None,
name="value")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment