Commit 448c31b6 authored by Zongwei Zhou, committed by zongweiz

[Transformer] Use float16 input and output for softmax in mixed-precision training

parent 49b90e86
@@ -21,24 +21,6 @@ from __future__ import print_function
 import tensorflow as tf
 
 
-def _float32_softmax(logits, name=None):
-  """Computes a softmax activation in float32.
-
-  When training a model using float16, softmax is still done in float32 for
-  numeric stability.
-
-  Args:
-    logits: A tensor, with any shape accepted by `tf.nn.softmax`.
-
-  Returns:
-    A tensor with the same dtype as `logits`.
-  """
-  input_dtype = logits.dtype
-  logits = tf.cast(logits, tf.float32)
-  output = tf.nn.softmax(logits, name=name)
-  return tf.cast(output, input_dtype)
-
-
 class Attention(tf.keras.layers.Layer):
   """Multi-headed attention layer."""
@@ -166,7 +148,10 @@ class Attention(tf.keras.layers.Layer):
     # Calculate dot product attention
     logits = tf.matmul(q, k, transpose_b=True)
     logits += bias
-    weights = _float32_softmax(logits, name="attention_weights")
+    # Note that softmax internally performs math operations using float32
+    # for numeric stability. When training with float16, we keep the input
+    # and output in float16 for better performance.
+    weights = tf.nn.softmax(logits, name="attention_weights")
     if training:
       weights = tf.nn.dropout(weights, rate=self.attention_dropout)
     attention_output = tf.matmul(weights, v)
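
A similar standalone check (again not from the repository; the shapes, the all-zero bias, and the variable names are invented for illustration, assuming TF 2.x eager execution) showing that the dot-product attention path above stays in float16 end to end:

import tensorflow as tf

# Hypothetical sizes: batch=2, heads=4, length=8, depth per head=16.
q = tf.random.uniform([2, 4, 8, 16], dtype=tf.float16)
k = tf.random.uniform([2, 4, 8, 16], dtype=tf.float16)
v = tf.random.uniform([2, 4, 8, 16], dtype=tf.float16)
bias = tf.zeros([2, 1, 1, 8], dtype=tf.float16)  # broadcasts over heads and query positions

logits = tf.matmul(q, k, transpose_b=True)  # [2, 4, 8, 8]
logits += bias
weights = tf.nn.softmax(logits, name="attention_weights")
attention_output = tf.matmul(weights, v)  # [2, 4, 8, 16]
print(weights.dtype, attention_output.dtype)  # float16 float16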