Commit 5f23689e authored by Frederick Liu, committed by A. Unique TensorFlower


[efficient] Fix is_short_seq order so that we can also apply feature transform and apply softmax afterwards.

PiperOrigin-RevId: 383967806
parent dbaec326
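To summarize the change below: the query/key scaling and the feature transform now run before the short-sequence branch, so the softmax path also operates on transformed features instead of returning early on the raw query and key. A minimal, self-contained sketch of the resulting control flow (illustrative names and shapes, not the layer's actual code):

import math
import tensorflow as tf


def compute_kernel_attention_sketch(query, key, value, feature_map, scale,
                                    attention_mask=None, is_short_seq=False):
  """query: [B, T, N, H], key: [B, S, N, H], value: [B, S, N, D]."""
  if is_short_seq:
    # Scale only the smaller einsum operand (helps XLA, per the diff comment).
    query = query * scale
  else:
    # Split the scale across query and key to reduce approximation variance.
    query = query * math.sqrt(scale)
    key = key * math.sqrt(scale)
  # The fix: the feature transform now runs before either branch below.
  query, key = feature_map(query), feature_map(key)
  if attention_mask is not None:
    key = tf.einsum("BSNH,BS->BSNH", key, attention_mask)
  if is_short_seq:
    scores = tf.nn.softmax(tf.einsum("BTNH,BSNH->BTSN", query, key), axis=2)
    return tf.einsum("BTSN,BSND->BTND", scores, value)
  kv = tf.einsum("BSNH,BSND->BNDH", key, value)
  denominator = 1.0 / (
      tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) + 1e-6)
  return tf.einsum("BTNH,BNDH,BTN->BTND", query, kv, denominator)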
@@ -134,11 +134,13 @@ _TRANSFORM_MAP = {
     "exp":
         functools.partial(
             _generalized_kernel,
             # Avoid exp explosion by shifting.
-            f=lambda x: tf.math.exp(
-                x - tf.math.reduce_max(x, axis=[1, 2, 3], keepdims=True)),
-            h=lambda x: tf.math.exp(
-                -0.5 * tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))),),
-    "identity": lambda x, projection_matrix, is_query: x
+            f=lambda x: tf.math.exp(x - tf.math.reduce_max(
+                x, axis=[1, 2, 3], keepdims=True)),
+            h=lambda x: tf.math.exp(-0.5 * tf.math.sqrt(
+                tf.cast(tf.shape(x)[-1], tf.float32))),
+        ),
+    "identity":
+        functools.partial(_generalized_kernel, f=lambda x: x, h=lambda x: 1)
 }
 # pylint: enable=g-long-lambda
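Each _TRANSFORM_MAP entry configures a Performer-style generalized kernel feature map, roughly phi(x) = h(x) * f(x W^T) / sqrt(m), where W is the optional random projection with m rows. The exact signature and projection layout of _generalized_kernel are not shown in this diff, so the following is only a hedged sketch under those assumptions, not the library's implementation:

import tensorflow as tf


def generalized_kernel_sketch(x, projection_matrix, f, h):
  """Feature map for x of shape [B, T, N, H]; projection_matrix: [M, H] or None (assumed layout)."""
  if projection_matrix is None:
    return h(x) * f(x)
  # Project head dims onto M random features: [B, T, N, H] x [M, H] -> [B, T, N, M].
  x_projected = tf.einsum("BTNH,MH->BTNM", x, projection_matrix)
  m = tf.cast(tf.shape(projection_matrix)[0], tf.float32)
  return h(x) * f(x_projected) / tf.math.sqrt(m)

Under this reading, the "exp" entry pairs a max-shifted exponential f (for numerical stability) with a constant h to approximate the softmax kernel, while the new "identity" entry (f(x) = x, h(x) = 1) reduces to plain linear attention.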
@@ -260,18 +262,6 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     Returns:
       attention_output: Multi-headed outputs of attention computation.
     """
-    if is_short_seq:
-      # Note: Applying scalar multiply at the smaller end of einsum improves
-      # XLA performance, but may introduce slight numeric differences in
-      # the Transformer attention head.
-      query = query * self._scale
-      if attention_mask is not None:
-        key = tf.einsum("BSNH,BS->BSNH", key, attention_mask)
-      attention_scores = tf.einsum("BTNH,BSNH->BTSN", query, key)
-      attention_scores = tf.nn.softmax(attention_scores, axis=2)
-      attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
-      return attention_output
-
     projection_matrix = None
     if self._num_random_features > 0:
       if self._redraw and training:
@@ -280,11 +270,17 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       else:
         projection_matrix = self._projection_matrix
 
-    # Note: we suspect spliting the scale to key, query yields smaller
-    # approximation variance when random projection is used.
-    # For simplicity, we also split when there's no random projection.
-    key *= math.sqrt(self._scale)
-    query *= math.sqrt(self._scale)
+    if is_short_seq:
+      # Note: Applying scalar multiply at the smaller end of einsum improves
+      # XLA performance, but may introduce slight numeric differences in
+      # the Transformer attention head.
+      query = query * self._scale
+    else:
+      # Note: we suspect spliting the scale to key, query yields smaller
+      # approximation variance when random projection is used.
+      # For simplicity, we also split when there's no random projection.
+      key *= math.sqrt(self._scale)
+      query *= math.sqrt(self._scale)
 
     key = _TRANSFORM_MAP[feature_transform](key, projection_matrix)
     query = _TRANSFORM_MAP[feature_transform](query, projection_matrix)
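A quick sanity check on the scale handling above: multiplying both query and key by sqrt(self._scale) yields the same attention scores as multiplying the full dot product by self._scale, so the two branches differ only in where a mathematically equivalent scaling is applied. Shapes below are illustrative:

import math
import tensorflow as tf

scale = 1.0 / 64.0  # e.g. 1 / head_dim
q = tf.random.normal([2, 5, 4, 64])  # [B, T, N, H]
k = tf.random.normal([2, 7, 4, 64])  # [B, S, N, H]

scores_scaled_once = scale * tf.einsum("BTNH,BSNH->BTSN", q, k)
scores_scale_split = tf.einsum("BTNH,BSNH->BTSN",
                               q * math.sqrt(scale), k * math.sqrt(scale))
# Any difference is floating-point noise.
print(float(tf.reduce_max(tf.abs(scores_scaled_once - scores_scale_split))))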
@@ -292,11 +288,18 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
 
     if attention_mask is not None:
       key = tf.einsum("BSNH,BS->BSNH", key, attention_mask)
-    kv = tf.einsum("BSNH,BSND->BNDH", key, value)
-    denominator = 1.0 / (
-        tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) +
-        _NUMERIC_STABLER)
-    return tf.einsum("BTNH,BNDH,BTN->BTND", query, kv, denominator)
+    if is_short_seq:
+      attention_scores = tf.einsum("BTNH,BSNH->BTSN", query, key)
+      attention_scores = tf.nn.softmax(attention_scores, axis=2)
+      attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
+    else:
+      kv = tf.einsum("BSNH,BSND->BNDH", key, value)
+      denominator = 1.0 / (
+          tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) +
+          _NUMERIC_STABLER)
+      attention_output = tf.einsum(
+          "BTNH,BNDH,BTN->BTND", query, kv, denominator)
+    return attention_output
 
   def _build_from_signature(self, query, value, key=None):
     super()._build_from_signature(query=query, value=value, key=key)
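The else branch above is the linear (kernel) attention path: by associativity, contracting keys with values first ("BSNH,BSND->BNDH") and normalizing each query by its dot product with the summed keys gives the same result as materializing the full [T, S] score matrix and row-normalizing it, but at cost linear in sequence length. A small numerical check with illustrative shapes and non-negative stand-ins for the transformed features:

import tensorflow as tf

B, T, S, N, H, D = 2, 5, 7, 3, 4, 6
q = tf.nn.relu(tf.random.normal([B, T, N, H]))  # stand-in for transformed query
k = tf.nn.relu(tf.random.normal([B, S, N, H]))  # stand-in for transformed key
v = tf.random.normal([B, S, N, D])
eps = 1e-6  # plays the role of _NUMERIC_STABLER

# Explicit path: materialize the [B, T, S, N] scores, then row-normalize.
scores = tf.einsum("BTNH,BSNH->BTSN", q, k)
explicit = tf.einsum("BTSN,BSND->BTND", scores, v) / (
    tf.reduce_sum(scores, axis=2)[..., None] + eps)

# kv-first path, as in the else branch: never builds the [T, S] matrix.
kv = tf.einsum("BSNH,BSND->BNDH", k, v)
denominator = 1.0 / (
    tf.einsum("BTNH,BNH->BTN", q, tf.reduce_sum(k, axis=1)) + eps)
linear = tf.einsum("BTNH,BNDH,BTN->BTND", q, kv, denominator)

print(float(tf.reduce_max(tf.abs(explicit - linear))))  # ~1e-6, float noise only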