Commit cebc9e90 authored by Frederick Liu, committed by A. Unique TensorFlower
Browse files

[nlp][translation] Consistently use float32 for layer norm for both encoder and decoder.

PiperOrigin-RevId: 363233691
parent 0537226c
......@@ -232,7 +232,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm",
axis=-1,
epsilon=self._norm_epsilon))
epsilon=self._norm_epsilon,
dtype="float32"))
# Encoder-decoder attention.
self.encdec_attention = self._cross_attention_cls(
num_heads=self.num_attention_heads,
......@@ -250,7 +251,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
tf.keras.layers.LayerNormalization(
name="attention/encdec_output_layer_norm",
axis=-1,
epsilon=self._norm_epsilon))
epsilon=self._norm_epsilon,
dtype="float32"))
# Feed-forward projection.
self.intermediate_dense = tf.keras.layers.experimental.EinsumDense(
......@@ -273,7 +275,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
**common_kwargs)
self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
self.output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=self._norm_epsilon)
name="output_layer_norm", axis=-1,
epsilon=self._norm_epsilon, dtype="float32")
super().build(input_shape)
def get_config(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment