Commit 3fa65606 authored by A. Unique TensorFlower

Bert fp16 perf improvements: do the matmul in the intermediate layer in fp16, and also remove the explicit casting to fp32 for layerNorm.

PiperOrigin-RevId: 273379063
parent 41293260
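
In short, the change keeps the intermediate einsum and bias add in float16 and upcasts only the pre-activation tensor when the new fp32_activation flag is set. Below is a minimal standalone sketch of that pattern, not the commit's code: gelu and intermediate_projection are hypothetical stand-ins for the module's gelu helper and Dense2DProjection.call, assuming TF 2.x with float16 inputs and weights.

import tensorflow as tf

def gelu(x):
  # Stand-in for the module's gelu helper (exact erf form).
  return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))

def intermediate_projection(inputs, kernel, bias, fp32_activation=True):
  # The matmul and bias add stay in the input dtype (float16 under fp16 training).
  ret = tf.einsum("abc,cd->abd", inputs, kernel) + bias
  if ret.dtype == tf.float16 and fp32_activation:
    # Only the activation is promoted to float32 for numeric stability.
    ret = tf.cast(ret, tf.float32)
  return gelu(ret)

x = tf.random.normal([8, 16, 32], dtype=tf.float16)
w = tf.random.normal([32, 64], dtype=tf.float16)
b = tf.zeros([64], dtype=tf.float16)
y = intermediate_projection(x, w, b)  # einsum runs in float16, gelu in float32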
@@ -637,12 +637,14 @@ class Dense2DProjection(tf.keras.layers.Layer):
                kernel_initializer=None,
                bias_initializer="zeros",
                activation=None,
+               fp32_activation=False,
                **kwargs):
     super(Dense2DProjection, self).__init__(**kwargs)
     self.output_size = output_size
     self.kernel_initializer = kernel_initializer
     self.bias_initializer = bias_initializer
     self.activation = activation
+    self.fp32_activation = fp32_activation

   def build(self, input_shape):
     """Implements build() for the layer."""
@@ -685,6 +687,8 @@ class Dense2DProjection(tf.keras.layers.Layer):
     ret = tf.einsum("abc,cd->abd", inputs, self.kernel)
     ret += self.bias
     if self.activation is not None:
+      if self.dtype == tf.float16 and self.fp32_activation:
+        ret = tf.cast(ret, tf.float32)
       return self.activation(ret)
     return ret

@@ -753,7 +757,7 @@ class TransformerBlock(tf.keras.layers.Layer):
         kernel_initializer=get_initializer(self.initializer_range),
         activation=self.intermediate_activation,
         # Uses float32 so that gelu activation is done in float32.
-        dtype=tf.float32,
+        fp32_activation=True,
         name="intermediate")
     self.output_dense = Dense2DProjection(
         output_size=self.hidden_size,
@@ -788,23 +792,16 @@ class TransformerBlock(tf.keras.layers.Layer):
     attention_output = self.attention_dropout(attention_output)
     # Use float32 in keras layer norm and the gelu activation in the
     # intermediate dense layer for numeric stability
-    # TODO(reedwm): These casts are probably unnecessary, as we passed
-    # dtype=tf.float32 to the layer norm constructor, so it will cast its inputs
-    # to float32 automatically. These manual casts additionally do the "+"
-    # operator in float32, but "+" is numerically stable in float16.
-    if self.float_type == tf.float16:
-      input_tensor = tf.cast(input_tensor, tf.float32)
-      attention_output = tf.cast(attention_output, tf.float32)
     attention_output = self.attention_layer_norm(input_tensor +
                                                  attention_output)
+    if self.float_type == tf.float16:
+      attention_output = tf.cast(attention_output, tf.float16)
     intermediate_output = self.intermediate_dense(attention_output)
     if self.float_type == tf.float16:
       intermediate_output = tf.cast(intermediate_output, tf.float16)
     layer_output = self.output_dense(intermediate_output)
     layer_output = self.output_dropout(layer_output)
     # Use float32 in keras layer norm for numeric stability
-    if self.float_type == tf.float16:
-      layer_output = tf.cast(layer_output, tf.float32)
     layer_output = self.output_layer_norm(layer_output + attention_output)
     if self.float_type == tf.float16:
       layer_output = tf.cast(layer_output, tf.float16)
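
The removed casts before the two layer norms lean on the behaviour the deleted TODO already described: a Keras layer constructed with dtype=tf.float32 casts its floating-point inputs up to float32 on its own, and the residual additions that now happen in float16 are the "+" that comment called numerically stable. A quick check of that autocasting assumption (TF 2.x; not part of the commit):

import tensorflow as tf

layer_norm = tf.keras.layers.LayerNormalization(axis=-1, dtype=tf.float32)
x = tf.random.normal([2, 4, 8], dtype=tf.float16)
y = layer_norm(x)
print(y.dtype)  # float32: the float16 input was upcast inside __call__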