"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "b1adde8e667e623e6812792912c6ed750bbb4fa3"
Commit 0257b276 authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower

Simplify LayerNorm mixed precision logic.

Instead of needing to ensure the variables are float32, cast the inputs to float32, etc., dtype="float32" is now passed to the layer constructor, which handles all of that logic automatically (see the sketch below).

The only visible difference is that the output of LayerNorm is now float32 instead of float16, so an extra cast is needed elsewhere.

PiperOrigin-RevId: 273833286
parent 3980d2a1
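The commit message relies on the Keras base Layer behavior: when a layer is constructed with dtype="float32", its weights are created in float32 and floating-point inputs are cast to float32 before call(), even when the surrounding model runs in float16 or bfloat16. The following is a minimal, hypothetical sketch of that mechanism, not part of this commit; ToyLayerNorm is an illustrative stand-in for the repo's LayerNormalization.

# Hypothetical sketch (not the repo's code): what dtype="float32" on a Keras
# layer gives you for free under mixed precision.
import tensorflow as tf


class ToyLayerNorm(tf.keras.layers.Layer):

  def __init__(self, hidden_size):
    # Forcing the layer dtype to float32 keeps the normalization math in
    # float32 even when surrounding layers run in float16/bfloat16.
    super(ToyLayerNorm, self).__init__(dtype="float32")
    self.hidden_size = hidden_size

  def build(self, input_shape):
    # Weights inherit the layer dtype (float32); no explicit dtype= or
    # experimental_autocast=False is needed.
    self.scale = self.add_weight(
        name="scale", shape=[self.hidden_size], initializer="ones")
    self.bias = self.add_weight(
        name="bias", shape=[self.hidden_size], initializer="zeros")
    super(ToyLayerNorm, self).build(input_shape)

  def call(self, x, epsilon=1e-6):
    # x has already been cast to float32 by the base Layer before call().
    mean = tf.reduce_mean(x, axis=[-1], keepdims=True)
    variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keepdims=True)
    norm_x = (x - mean) * tf.math.rsqrt(variance + epsilon)
    return norm_x * self.scale + self.bias


layer = ToyLayerNorm(4)
outputs = layer(tf.zeros([2, 4], dtype=tf.float16))
print(outputs.dtype)  # float32: callers that need float16 must now cast back,
outputs = tf.cast(outputs, tf.float16)  # e.g. the predict() change in the diff.

Under that assumption, the only change visible to callers is the output dtype, which is why the diff below adds a tf.cast of encoder_outputs in Transformer.predict.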
@@ -290,6 +290,7 @@ class Transformer(tf.keras.Model):
 
   def predict(self, encoder_outputs, encoder_decoder_attention_bias, training):
     """Return predicted sequence."""
+    encoder_outputs = tf.cast(encoder_outputs, self.params["dtype"])
     if self.params["padded_decode"]:
       batch_size = encoder_outputs.shape.as_list()[0]
       input_length = encoder_outputs.shape.as_list()[1]
@@ -356,27 +357,21 @@ class LayerNormalization(tf.keras.layers.Layer):
   """Applies layer normalization."""
 
   def __init__(self, hidden_size):
-    super(LayerNormalization, self).__init__()
+    # Pass dtype=float32, as we have not yet tested if layer norm is numerically
+    # stable in float16 and bfloat16.
+    super(LayerNormalization, self).__init__(dtype="float32")
     self.hidden_size = hidden_size
 
   def build(self, input_shape):
     """Builds the layer."""
-    # Passing experimental_autocast=False causes these variables to not be
-    # automatically casted to fp16 when mixed precision is used. Since we use
-    # float32 in call() for numeric stability, we do not want variables to be
-    # casted to fp16.
     self.scale = self.add_weight(
         "layer_norm_scale",
         shape=[self.hidden_size],
-        dtype="float32",
-        initializer=tf.ones_initializer(),
-        experimental_autocast=False)
+        initializer=tf.ones_initializer())
     self.bias = self.add_weight(
         "layer_norm_bias",
         shape=[self.hidden_size],
-        dtype="float32",
-        initializer=tf.zeros_initializer(),
-        experimental_autocast=False)
+        initializer=tf.zeros_initializer())
     super(LayerNormalization, self).build(input_shape)
 
   def get_config(self):
@@ -385,13 +380,10 @@ class LayerNormalization(tf.keras.layers.Layer):
     }
 
   def call(self, x, epsilon=1e-6):
-    input_dtype = x.dtype
-    if input_dtype == tf.float16 or input_dtype == tf.bfloat16:
-      x = tf.cast(x, tf.float32)
     mean = tf.reduce_mean(x, axis=[-1], keepdims=True)
     variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keepdims=True)
     norm_x = (x - mean) * tf.math.rsqrt(variance + epsilon)
-    return tf.cast(norm_x * self.scale + self.bias, input_dtype)
+    return norm_x * self.scale + self.bias
 
 
 class PrePostProcessingWrapper(tf.keras.layers.Layer):