Commit cebc9e90, authored by Frederick Liu and committed by A. Unique TensorFlower
Browse files

[nlp][translation] Consistently use float32 for layer norm for both encoder and decoder.

PiperOrigin-RevId: 363233691
parent 0537226c
...@@ -232,7 +232,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer): ...@@ -232,7 +232,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
tf.keras.layers.LayerNormalization( tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm", name="self_attention_layer_norm",
axis=-1, axis=-1,
epsilon=self._norm_epsilon)) epsilon=self._norm_epsilon,
dtype="float32"))
# Encoder-decoder attention. # Encoder-decoder attention.
self.encdec_attention = self._cross_attention_cls( self.encdec_attention = self._cross_attention_cls(
num_heads=self.num_attention_heads, num_heads=self.num_attention_heads,
...@@ -250,7 +251,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer): ...@@ -250,7 +251,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
tf.keras.layers.LayerNormalization( tf.keras.layers.LayerNormalization(
name="attention/encdec_output_layer_norm", name="attention/encdec_output_layer_norm",
axis=-1, axis=-1,
epsilon=self._norm_epsilon)) epsilon=self._norm_epsilon,
dtype="float32"))
# Feed-forward projection. # Feed-forward projection.
self.intermediate_dense = tf.keras.layers.experimental.EinsumDense( self.intermediate_dense = tf.keras.layers.experimental.EinsumDense(
...@@ -273,7 +275,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer): ...@@ -273,7 +275,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
**common_kwargs) **common_kwargs)
self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate) self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
self.output_layer_norm = tf.keras.layers.LayerNormalization( self.output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=self._norm_epsilon) name="output_layer_norm", axis=-1,
epsilon=self._norm_epsilon, dtype="float32")
super().build(input_shape) super().build(input_shape)
def get_config(self): def get_config(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment