Commit 7aab3475 authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower

[efficient] Fix is_short_seq order so that we can also apply feature transform and apply softmax afterwards.

PiperOrigin-RevId: 383967806
parent 9a052f52
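Roughly, the change makes the short-sequence (`is_short_seq`) branch run after the feature transform instead of returning early before it, so the kernel feature map is applied first and the softmax afterwards. Below is a minimal sketch of the new ordering; `feature_transform` and `scale` are illustrative stand-ins for the `_TRANSFORM_MAP` lookup and the layer's scaling, not the layer's exact API.

import tensorflow as tf

def short_seq_attention(query, key, value, feature_transform, scale,
                        attention_mask=None):
  """Illustrative only: feature transform first, softmax afterwards."""
  query = feature_transform(query * scale)
  key = feature_transform(key)
  if attention_mask is not None:
    key = tf.einsum("BSNH,BS->BSNH", key, attention_mask)
  scores = tf.einsum("BTNH,BSNH->BTSN", query, key)
  scores = tf.nn.softmax(scores, axis=2)  # softmax over the key axis S
  return tf.einsum("BTSN,BSNH->BTNH", scores, value)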
@@ -134,11 +134,13 @@ _TRANSFORM_MAP = {
         functools.partial(
             _generalized_kernel,
             # Avoid exp explosion by shifting.
-            f=lambda x: tf.math.exp(
-                x - tf.math.reduce_max(x, axis=[1, 2, 3], keepdims=True)),
-            h=lambda x: tf.math.exp(
-                -0.5 * tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))),),
-    "identity": lambda x, projection_matrix, is_query: x
+            f=lambda x: tf.math.exp(x - tf.math.reduce_max(
+                x, axis=[1, 2, 3], keepdims=True)),
+            h=lambda x: tf.math.exp(-0.5 * tf.math.sqrt(
+                tf.cast(tf.shape(x)[-1], tf.float32))),
+        ),
+    "identity":
+        functools.partial(_generalized_kernel, f=lambda x: x, h=lambda x: 1)
 }
 # pylint: enable=g-long-lambda
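The second change in this hunk: "identity" previously bypassed _generalized_kernel via a bare three-argument lambda; it is now built with the same functools.partial pattern as the other entries, so every _TRANSFORM_MAP value shares one code path (including the optional random-feature projection). The snippet below only illustrates the calling convention; _toy_generalized_kernel is a hypothetical stand-in, not the helper defined earlier in kernel_attention.py.

import functools

def _toy_generalized_kernel(x, projection_matrix, is_query, f, h):
  # Stand-in with the same (x, projection_matrix, is_query) signature the
  # _TRANSFORM_MAP callers use; the real helper also applies the projection.
  del projection_matrix, is_query
  return h(x) * f(x)

identity = functools.partial(_toy_generalized_kernel,
                             f=lambda x: x, h=lambda x: 1)
assert identity(3.0, None, True) == 3.0  # same result as the old bare lambda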
@@ -260,18 +262,6 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     Returns:
       attention_output: Multi-headed outputs of attention computation.
     """
-    if is_short_seq:
-      # Note: Applying scalar multiply at the smaller end of einsum improves
-      # XLA performance, but may introduce slight numeric differences in
-      # the Transformer attention head.
-      query = query * self._scale
-      if attention_mask is not None:
-        key = tf.einsum("BSNH,BS->BSNH", key, attention_mask)
-      attention_scores = tf.einsum("BTNH,BSNH->BTSN", query, key)
-      attention_scores = tf.nn.softmax(attention_scores, axis=2)
-      attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
-      return attention_output
-
     projection_matrix = None
     if self._num_random_features > 0:
       if self._redraw and training:
@@ -280,6 +270,12 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       else:
         projection_matrix = self._projection_matrix
+    if is_short_seq:
+      # Note: Applying scalar multiply at the smaller end of einsum improves
+      # XLA performance, but may introduce slight numeric differences in
+      # the Transformer attention head.
+      query = query * self._scale
+    else:
       # Note: we suspect spliting the scale to key, query yields smaller
       # approximation variance when random projection is used.
       # For simplicity, we also split when there's no random projection.
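Splitting the scale between query and key leaves the attention logits unchanged, assuming (as the names suggest) that self._scale_sqrt is the square root of self._scale: (q * sqrt(s)) . (k * sqrt(s)) == (q . k) * s. A quick numeric check with illustrative values and assumed definitions of the two scales:

import tensorflow as tf

head_dim = 64
scale = 1.0 / tf.sqrt(float(head_dim))  # assumed definition of self._scale
scale_sqrt = tf.sqrt(scale)             # assumed: self._scale_sqrt == sqrt(scale)
q = tf.random.normal([head_dim])
k = tf.random.normal([head_dim])
split = tf.reduce_sum((q * scale_sqrt) * (k * scale_sqrt))
single = tf.reduce_sum(q * k) * scale
print(split.numpy(), single.numpy())  # equal up to float rounding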
@@ -292,11 +288,18 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     if attention_mask is not None:
       key = tf.einsum("BSNH,BS->BSNH", key, attention_mask)
+    if is_short_seq:
+      attention_scores = tf.einsum("BTNH,BSNH->BTSN", query, key)
+      attention_scores = tf.nn.softmax(attention_scores, axis=2)
+      attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
+    else:
       kv = tf.einsum("BSNH,BSND->BNDH", key, value)
       denominator = 1.0 / (
           tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) +
           _NUMERIC_STABLER)
-    return tf.einsum("BTNH,BNDH,BTN->BTND", query, kv, denominator)
+      attention_output = tf.einsum(
+          "BTNH,BNDH,BTN->BTND", query, kv, denominator)
+    return attention_output

   def _build_from_signature(self, query, value, key=None):
     super()._build_from_signature(query=query, value=value, key=key)
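The long-sequence branch never materializes the T x S score matrix: it contracts key with value first and normalizes by the summed key features. A small self-contained check (tensor names and sizes are illustrative) that this einsum formulation matches an explicit computation of kernel-attention weights proportional to <query, key>:

import tensorflow as tf

B, T, S, N, H, D = 2, 3, 5, 2, 4, 6
_NUMERIC_STABLER = 1e-6

# Pretend query/key have already gone through a non-negative feature transform.
query = tf.random.uniform([B, T, N, H])
key = tf.random.uniform([B, S, N, H])
value = tf.random.normal([B, S, N, D])

# Linear (non-softmax) path, as in the else-branch of the diff above.
kv = tf.einsum("BSNH,BSND->BNDH", key, value)
denominator = 1.0 / (
    tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) +
    _NUMERIC_STABLER)
linear_out = tf.einsum("BTNH,BNDH,BTN->BTND", query, kv, denominator)

# Reference: explicit attention weights, normalized over the key axis S.
scores = tf.einsum("BTNH,BSNH->BTSN", query, key)
weights = scores / (tf.reduce_sum(scores, axis=2, keepdims=True) +
                    _NUMERIC_STABLER)
reference_out = tf.einsum("BTSN,BSND->BTND", weights, value)

print(tf.reduce_max(tf.abs(linear_out - reference_out)).numpy())  # ~0, float tolerance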