Commit 9a3e9518 authored by Scott Zhu, committed by A. Unique TensorFlower

Prepare for upcoming keras initializer change.

PiperOrigin-RevId: 451474703
parent a945ae9d
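The diff below routes every place that previously passed a shared initializer object straight into a layer through `tf_utils.clone_initializer`, so each layer receives its own copy. As context, here is a minimal sketch of the cloning pattern, assuming the upcoming Keras change makes initializers stateless (so reusing one instance for several same-shaped weights would yield identical initial values); the real helper lives in `official/modeling/tf_utils.py`, and the body shown here is an illustration of the idea, not the verbatim source:

    import tensorflow as tf

    def clone_initializer(initializer):
      # Rebuild a configured Keras initializer from its config so each layer
      # gets an independent instance. (Assumed behavior for illustration.)
      if isinstance(initializer, tf.keras.initializers.Initializer):
        return initializer.__class__.from_config(initializer.get_config())
      # Strings/dicts are deserialized into fresh instances by Keras anyway,
      # so they can be passed through unchanged.
      return initializer

    # Hypothetical usage mirroring this commit: one configured initializer,
    # cloned per layer instead of passed by reference.
    kernel_init = tf.keras.initializers.TruncatedNormal(stddev=0.02)
    dense_a = tf.keras.layers.Dense(
        8, kernel_initializer=clone_initializer(kernel_init))
    dense_b = tf.keras.layers.Dense(
        8, kernel_initializer=clone_initializer(kernel_init))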
@@ -20,6 +20,7 @@ cross-attention layer.
 import tensorflow as tf
+from official.modeling import tf_utils
 from official.nlp.modeling import layers
 from official.nlp.modeling import models
@@ -91,8 +92,9 @@ class TransformerEncoder(tf.keras.layers.Layer):
           norm_first=self._norm_first,
           norm_epsilon=self._norm_epsilon,
           inner_dropout=self._intermediate_dropout,
-          attention_initializer=models.seq2seq_transformer
-          .attention_initializer(input_shape[2]),
+          attention_initializer=tf_utils.clone_initializer(
+              models.seq2seq_transformer.attention_initializer(
+                  input_shape[2])),
           name=("layer_%d" % i)))
     self.output_normalization = tf.keras.layers.LayerNormalization(
         epsilon=self._norm_epsilon, dtype="float32")
@@ -234,7 +236,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
       self._attention_initializer = tf.keras.initializers.get(
           attention_initializer)
     else:
-      self._attention_initializer = self._kernel_initializer
+      self._attention_initializer = tf_utils.clone_initializer(
+          self._kernel_initializer)
     self._attention_axes = attention_axes

   def build(self, input_shape):
@@ -284,7 +287,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         einsum_equation,
         output_shape=(None, self._inner_dim),
         bias_axes="d",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
         name="intermediate",
         **common_kwargs)
     policy = tf.keras.mixed_precision.global_policy()
@@ -302,7 +305,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         output_shape=(None, hidden_size),
         bias_axes="d",
         name="output",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
         **common_kwargs)
     self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
     # Use float32 in layernorm for numeric stability.
@@ -490,8 +493,9 @@ class TransformerDecoder(tf.keras.layers.Layer):
           norm_first=self._norm_first,
           norm_epsilon=self._norm_epsilon,
           intermediate_dropout=self._intermediate_dropout,
-          attention_initializer=models.seq2seq_transformer
-          .attention_initializer(input_shape[2]),
+          attention_initializer=tf_utils.clone_initializer(
+              models.seq2seq_transformer.attention_initializer(
+                  input_shape[2])),
           name=("layer_%d" % i)))
     self.output_normalization = tf.keras.layers.LayerNormalization(
         epsilon=self._norm_epsilon, dtype="float32")
@@ -656,7 +660,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
      self._attention_initializer = tf.keras.initializers.get(
           attention_initializer)
     else:
-      self._attention_initializer = self._kernel_initializer
+      self._attention_initializer = tf_utils.clone_initializer(
+          self._kernel_initializer)
     self._cross_attention_cls = layers.attention.MultiHeadAttention

   def build(self, input_shape):
@@ -690,7 +695,7 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
         "abc,cd->abd",
         output_shape=(None, hidden_size),
         bias_axes="d",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
         name="output",
         **common_kwargs)
     self.self_attention_dropout = tf.keras.layers.Dropout(
@@ -726,7 +731,7 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
         "abc,cd->abd",
         output_shape=(None, self.intermediate_size),
         bias_axes="d",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
         name="intermediate",
         **common_kwargs)
     self.intermediate_activation_layer = tf.keras.layers.Activation(
@@ -737,7 +742,7 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
         "abc,cd->abd",
         output_shape=(None, hidden_size),
         bias_axes="d",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
         name="output",
         **common_kwargs)
     self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)