Commit 7278d89b authored by Scott Zhu, committed by A. Unique TensorFlower

Prepare for upcoming keras initializer change.

PiperOrigin-RevId: 451474703
parent 5b964dbb
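
For readers outside the Model Garden, here is a minimal sketch of the pattern this commit rolls out. The assumption (inferred from the commit title, not shown in the diff) is that the upcoming Keras change makes an initializer instance effectively stateless, so reusing one object for several variables of the same shape can yield identical initial values; `tf_utils.clone_initializer` gives each layer its own copy. Layer sizes below are arbitrary.

```python
import tensorflow as tf
from official.modeling import tf_utils

# One unseeded initializer instance.
init = tf.keras.initializers.GlorotUniform()

# Reusing the same instance: under the assumed stateless behavior, kernels of
# the same shape may end up initialized to identical values.
shared_a = tf.keras.layers.Dense(64, kernel_initializer=init)
shared_b = tf.keras.layers.Dense(64, kernel_initializer=init)

# The pattern adopted in this commit: each layer receives an independent
# clone, so initial values stay decorrelated regardless of the new semantics.
cloned_a = tf.keras.layers.Dense(
    64, kernel_initializer=tf_utils.clone_initializer(init))
cloned_b = tf.keras.layers.Dense(
    64, kernel_initializer=tf_utils.clone_initializer(init))
```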
@@ -20,6 +20,7 @@ cross-attention layer.
 import tensorflow as tf
+from official.modeling import tf_utils
 from official.nlp.modeling import layers
 from official.nlp.modeling import models
@@ -91,8 +92,9 @@ class TransformerEncoder(tf.keras.layers.Layer):
               norm_first=self._norm_first,
               norm_epsilon=self._norm_epsilon,
               inner_dropout=self._intermediate_dropout,
-              attention_initializer=models.seq2seq_transformer
-              .attention_initializer(input_shape[2]),
+              attention_initializer=tf_utils.clone_initializer(
+                  models.seq2seq_transformer.attention_initializer(
+                      input_shape[2])),
               name=("layer_%d" % i)))
     self.output_normalization = tf.keras.layers.LayerNormalization(
         epsilon=self._norm_epsilon, dtype="float32")
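
The hunk above wraps the per-layer attention initializer before handing it to each encoder block. As a rough, hypothetical sketch of what `tf_utils.clone_initializer` is assumed to do (its implementation is not part of this diff): rebuild Keras `Initializer` objects from their config so each consumer gets a fresh instance, and pass anything else through untouched.

```python
import tensorflow as tf


def clone_initializer_sketch(initializer):
  # Hypothetical stand-in for tf_utils.clone_initializer; assumed behavior only.
  if isinstance(initializer, tf.keras.initializers.Initializer):
    # A config round-trip yields a new instance with the same settings.
    return initializer.__class__.from_config(initializer.get_config())
  # Strings, dicts, and callables are left for Keras to resolve per layer.
  return initializer
```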
@@ -234,7 +236,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
       self._attention_initializer = tf.keras.initializers.get(
           attention_initializer)
     else:
-      self._attention_initializer = self._kernel_initializer
+      self._attention_initializer = tf_utils.clone_initializer(
+          self._kernel_initializer)
     self._attention_axes = attention_axes
 
   def build(self, input_shape):
@@ -284,7 +287,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         einsum_equation,
         output_shape=(None, self._inner_dim),
         bias_axes="d",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
         name="intermediate",
         **common_kwargs)
     policy = tf.keras.mixed_precision.global_policy()
@@ -302,7 +305,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         output_shape=(None, hidden_size),
         bias_axes="d",
         name="output",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
         **common_kwargs)
     self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
     # Use float32 in layernorm for numeric stability.
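
The two hunks above repeat the same mechanical substitution inside `build()`: the block keeps one stored `self._kernel_initializer`, but every `EinsumDense` it creates now receives its own clone. A standalone, hypothetical version of that pattern (equation and dimensions chosen arbitrarily for illustration):

```python
import tensorflow as tf
from official.modeling import tf_utils

kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev=0.02)

# The "intermediate" and "output" projections no longer share one initializer
# object; each gets an independent clone of the same configuration.
intermediate = tf.keras.layers.EinsumDense(
    "abc,cd->abd",
    output_shape=(None, 1024),
    bias_axes="d",
    kernel_initializer=tf_utils.clone_initializer(kernel_initializer),
    name="intermediate")
output = tf.keras.layers.EinsumDense(
    "abc,cd->abd",
    output_shape=(None, 256),
    bias_axes="d",
    kernel_initializer=tf_utils.clone_initializer(kernel_initializer),
    name="output")
```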
@@ -490,8 +493,9 @@ class TransformerDecoder(tf.keras.layers.Layer):
               norm_first=self._norm_first,
               norm_epsilon=self._norm_epsilon,
               intermediate_dropout=self._intermediate_dropout,
-              attention_initializer=models.seq2seq_transformer
-              .attention_initializer(input_shape[2]),
+              attention_initializer=tf_utils.clone_initializer(
+                  models.seq2seq_transformer.attention_initializer(
+                      input_shape[2])),
               name=("layer_%d" % i)))
     self.output_normalization = tf.keras.layers.LayerNormalization(
         epsilon=self._norm_epsilon, dtype="float32")
@@ -656,7 +660,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
       self._attention_initializer = tf.keras.initializers.get(
           attention_initializer)
     else:
-      self._attention_initializer = self._kernel_initializer
+      self._attention_initializer = tf_utils.clone_initializer(
+          self._kernel_initializer)
     self._cross_attention_cls = layers.attention.MultiHeadAttention
 
   def build(self, input_shape):
@@ -690,7 +695,7 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
         "abc,cd->abd",
         output_shape=(None, hidden_size),
         bias_axes="d",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
         name="output",
         **common_kwargs)
     self.self_attention_dropout = tf.keras.layers.Dropout(
@@ -726,7 +731,7 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
         "abc,cd->abd",
         output_shape=(None, self.intermediate_size),
         bias_axes="d",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
         name="intermediate",
         **common_kwargs)
     self.intermediate_activation_layer = tf.keras.layers.Activation(
@@ -737,7 +742,7 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
         "abc,cd->abd",
         output_shape=(None, hidden_size),
         bias_axes="d",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
         name="output",
         **common_kwargs)
     self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
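
A quick, hedged way to check the intended effect: independent clones should draw different initial values for the same shape, whereas reusing a single instance might not once the Keras change lands. The shapes below are arbitrary, and the expectation reflects the assumed new semantics rather than anything guaranteed by this diff.

```python
import numpy as np
import tensorflow as tf
from official.modeling import tf_utils

init = tf.keras.initializers.GlorotUniform()
a = tf_utils.clone_initializer(init)(shape=(8, 8))
b = tf_utils.clone_initializer(init)(shape=(8, 8))

# Clones are expected to produce decorrelated values even after initializers
# become stateless; calling the same `init` object twice may not.
assert not np.allclose(a.numpy(), b.numpy())
```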