Commit cdda0906 authored by Allen Wang, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 322216928
parent d3d2ad3d
@@ -366,12 +366,16 @@ class MultiHeadAttention(tf.keras.layers.Layer):
       attention_output: Multi-headed outputs of attention computation.
       attention_scores: Multi-headed attention weights.
     """
+    # Note: Applying scalar multiply at the smaller end of einsum improves
+    # XLA performance, but may introduce slight numeric differences in
+    # the Transformer attention head.
+    query_tensor = tf.multiply(query_tensor,
+                               1.0 / math.sqrt(float(self._key_size)))
     # Take the dot product between "query" and "key" to get the raw
     # attention scores.
     attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
                                  query_tensor)
-    attention_scores = tf.multiply(attention_scores,
-                                   1.0 / math.sqrt(float(self._key_size)))
     # Normalize the attention scores to probabilities.
     # `attention_scores` = [B, N, T, S]
...
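Why this reordering is safe: scalar multiplication commutes with the einsum contraction, so scaling the query tensor before the dot product is algebraically identical to scaling the [B, N, T, S] score tensor afterwards; only the placement of floating-point rounding changes, which is why the tests below relax their tolerances. A minimal sketch of the equivalence check follows; the shapes and the 'aecd,abcd->acbe' equation are illustrative stand-ins for self._dot_product_equation, not taken from this diff:

import math
import numpy as np
import tensorflow as tf

key_size = 16
# Illustrative shapes: [batch, seq_length, num_heads, head_size].
query_tensor = tf.random.normal([2, 8, 4, key_size])
key_tensor = tf.random.normal([2, 8, 4, key_size])

# Old ordering: contract first, then scale the large [B, N, T, S] tensor.
scores_old = tf.einsum('aecd,abcd->acbe', key_tensor, query_tensor)
scores_old = tf.multiply(scores_old, 1.0 / math.sqrt(float(key_size)))

# New ordering: scale the smaller query tensor, then contract.
scaled_query = tf.multiply(query_tensor,
                           1.0 / math.sqrt(float(key_size)))
scores_new = tf.einsum('aecd,abcd->acbe', key_tensor, scaled_query)

# Equal up to float32 rounding, which motivates the relaxed test
# tolerances in the second hunk of this commit.
np.testing.assert_allclose(scores_new.numpy(), scores_old.numpy(),
                           rtol=1e-5, atol=1e-5)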
@@ -152,7 +152,10 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     _ = new_layer([input_data, mask_data])
     new_layer.set_weights(test_layer.get_weights())
     new_output_tensor = new_layer([input_data, mask_data])
-    self.assertAllClose(new_output_tensor, output_tensor[:, 0:1, :])
+    self.assertAllClose(new_output_tensor,
+                        output_tensor[:, 0:1, :],
+                        atol=5e-5,
+                        rtol=0.003)
 
   def test_layer_invocation_with_float16_dtype(self, transformer_cls):
     tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
...
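For reference, assertAllClose uses numpy-style tolerances: a comparison passes when |actual - desired| <= atol + rtol * |desired| elementwise, so atol=5e-5 covers near-zero outputs while rtol=0.003 covers larger ones. A small sketch of that semantics under illustrative values (not taken from the test):

import numpy as np

# Near zero, the absolute tolerance dominates: 4e-5 <= 5e-5 + 0.003 * 1e-4.
desired = np.float32(1e-4)
actual = desired + np.float32(4e-5)
np.testing.assert_allclose(actual, desired, atol=5e-5, rtol=0.003)

# For larger values, the relative tolerance dominates: 0.02 <= 0.03005.
desired = np.float32(10.0)
actual = desired * np.float32(1.002)
np.testing.assert_allclose(actual, desired, atol=5e-5, rtol=0.003)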