"research/attention_ocr/python/demo_inference.py" did not exist on "da341f70faaade5dbdc854be04bb13ea9c777909"
Commit 448c31b6 authored by Zongwei Zhou, committed by zongweiz

[Transformer] Use float16 input and output for softmax in mixed-precision training

parent 49b90e86
@@ -21,24 +21,6 @@ from __future__ import print_function

 import tensorflow as tf


-def _float32_softmax(logits, name=None):
-  """Computes a softmax activation in float32.
-
-  When training a model using float16, softmax is still done in float32 for
-  numeric stability.
-
-  Args:
-    logits: A tensor, with any shape accepted by `tf.nn.softmax`.
-
-  Returns:
-    A tensor with the same dtype as `logits`.
-  """
-  input_dtype = logits.dtype
-  logits = tf.cast(logits, tf.float32)
-  output = tf.nn.softmax(logits, name=name)
-  return tf.cast(output, input_dtype)
-
-
 class Attention(tf.keras.layers.Layer):
   """Multi-headed attention layer."""

@@ -166,7 +148,10 @@ class Attention(tf.keras.layers.Layer):
     # Calculate dot product attention
     logits = tf.matmul(q, k, transpose_b=True)
     logits += bias
-    weights = _float32_softmax(logits, name="attention_weights")
+    # Note that softmax internally performs math operations using float32
+    # for numeric stability. When training with float16, we keep the input
+    # and output in float16 for better performance.
+    weights = tf.nn.softmax(logits, name="attention_weights")
     if training:
       weights = tf.nn.dropout(weights, rate=self.attention_dropout)
     attention_output = tf.matmul(weights, v)
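Below is a minimal sketch (not part of the commit) illustrating the dtype behavior the change relies on: `tf.nn.softmax` applied directly to float16 logits returns float16 attention weights, so the casts to and from float32 in the removed `_float32_softmax` wrapper are no longer needed at the input/output boundary; per the new in-code comment, the op is still expected to do its internal math in float32 for numeric stability. Shapes and values here are made up for illustration.

```python
# Illustrative sketch only; shapes and values are arbitrary.
import numpy as np
import tensorflow as tf

# Fake attention logits in float16, as produced under mixed-precision training.
logits = tf.constant(np.random.randn(2, 4, 8), dtype=tf.float16)

# After this commit: softmax runs directly on the float16 logits and the
# resulting attention weights stay in float16.
weights = tf.nn.softmax(logits, name="attention_weights")
print(weights.dtype)  # float16

# The removed helper produced the same output dtype, but paid for two extra
# casts by round-tripping through float32.
def _float32_softmax(logits, name=None):
  input_dtype = logits.dtype
  logits = tf.cast(logits, tf.float32)
  output = tf.nn.softmax(logits, name=name)
  return tf.cast(output, input_dtype)

print(_float32_softmax(logits).dtype)  # also float16, with extra casts
```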