Commit 10673875 authored by Frederick Liu, committed by A. Unique TensorFlower

[kernel] Add streaming support.

PiperOrigin-RevId: 477214841
parent 798d318f
@@ -160,7 +160,8 @@ def causal_windowed_performer_attention(query_matrix,
                                         chunk_length,
                                         window_length,
                                         window_decay=None,
-                                        padding=None):
+                                        padding=None,
+                                        cache=None):
   """Applies windowed causal kernel attention with query, key, value tensors.
 
   We partition the T-length input sequence into N chunks, each of
@@ -202,10 +203,13 @@ def causal_windowed_performer_attention(query_matrix,
       padding if padding is set to None. In the latter case, the axis dimension
       of the query, value and key input tensors must be divisible by the
       chunk_length.
+    cache: Cache to accumulate history in memory. Used at inference time
+      (streaming, decoding) for causal attention.
 
   Returns:
     Window causal performer attention of shape `[B, T, H, out_dim]`.
   """
-  old_shape = tf.shape(value_matrix)
+  if cache is None:  # Training
+    old_shape = tf.shape(value_matrix)
 
-  query_matrix = pad_to_chunk_length(query_matrix, -3, chunk_length, padding)
+    query_matrix = pad_to_chunk_length(query_matrix, -3, chunk_length, padding)
@@ -239,18 +243,30 @@ def causal_windowed_performer_attention(query_matrix,
-  kp_v_winsum = weighted_window_sum(kp_v, window_length, decaying_weights)
-  k_winsum = weighted_window_sum(k_sum, window_length, decaying_weights)
+    kp_v_winsum = weighted_window_sum(kp_v, window_length, decaying_weights)
+    k_winsum = weighted_window_sum(k_sum, window_length, decaying_weights)
 
-  numerator = tf.einsum("BTCHD,BTHDO->BTCHO", chunked_query_matrix, kp_v_winsum)
+    numerator = tf.einsum(
+        "BTCHD,BTHDO->BTCHO", chunked_query_matrix, kp_v_winsum)
 
-  k_winsum = tf.squeeze(k_winsum, -3)
-  denominator = tf.einsum("BTCHD,BTHD->BTCH", chunked_query_matrix, k_winsum)
-  denominator = tf.expand_dims(denominator, -1) + _NUMERIC_STABLER
+    k_winsum = tf.squeeze(k_winsum, -3)
+    denominator = tf.einsum("BTCHD,BTHD->BTCH", chunked_query_matrix, k_winsum)
+    denominator = tf.expand_dims(denominator, -1) + _NUMERIC_STABLER
 
-  attention = numerator / denominator
-  attention = tf.reshape(attention, new_shape)
+    attention = numerator / denominator
+    attention = tf.reshape(attention, new_shape)
 
-  start = tf.zeros([len(old_shape)], dtype=old_shape.dtype)
-  attention = tf.slice(attention, start, old_shape)
+    start = tf.zeros([len(old_shape)], dtype=old_shape.dtype)
+    attention = tf.slice(attention, start, old_shape)
+  # Queued window cache (drop instead of decay) not yet supported.
+  else:  # Streaming
+    if window_decay is None or window_decay > 1.0 or window_decay < 0.0:
+      raise ValueError("window_decay should be in (0.0, 1.0) and not None.")
+    kv = cache["kv"] + tf.einsum("BTHD,BTHO->BHOD", key_matrix, value_matrix)
+    cache["kv"] = kv * window_decay
+    k_sum = cache["k_sum"] + tf.reduce_sum(key_matrix, axis=1)
+    cache["k_sum"] = k_sum * window_decay
+    denominator = tf.einsum("BTHD,BHD->BTH", query_matrix, k_sum)
+    attention = tf.einsum("BTHD,BHOD,BTH->BTHO", query_matrix, kv,
+                          1.0 / (denominator + _NUMERIC_STABLER))
 
   return attention
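A worked form of the streaming recurrence above (a sketch, assuming the feature map phi has already been applied to the query and key inputs, gamma = window_decay, and chunks indexed by n):

$$\mathrm{KV}_n = \mathrm{cache[kv]}_{n-1} + \sum_{t\in\mathrm{chunk}_n} v_t\,\phi(k_t)^\top, \qquad \mathrm{cache[kv]}_n = \gamma\,\mathrm{KV}_n,$$
$$S_n = \mathrm{cache[k\_sum]}_{n-1} + \sum_{t\in\mathrm{chunk}_n} \phi(k_t), \qquad \mathrm{cache[k\_sum]}_n = \gamma\,S_n,$$
$$\mathrm{attention}_n(q) = \frac{\mathrm{KV}_n\,\phi(q)}{\phi(q)^\top S_n + \epsilon}.$$

Each chunk attends to the undecayed sums that include itself, while the cache stored for the next chunk is decayed once per chunk; unrolling gives $\mathrm{KV}_n = \sum_{m\le n} \gamma^{\,n-m} \sum_{t\in\mathrm{chunk}_m} v_t\,\phi(k_t)^\top$, i.e. an exponentially decayed, unbounded window over past chunks (a queued, drop-based window is noted above as not yet supported).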
@@ -443,7 +459,7 @@ def expplus(data_orig,
 # pylint: disable=g-long-lambda
-_TRANSFORM_MAP = {
+_CAUSAL_SUPPORT_TRANSFORM_MAP = {
     "elu":
         functools.partial(
             _generalized_kernel,
@@ -476,11 +492,19 @@ _TRANSFORM_MAP = {
             h=lambda x: tf.math.exp(-0.5 * tf.math.sqrt(
                 tf.cast(tf.shape(x)[-1], tf.float32))),
         ),
-    "expplus":
-        expplus,
     "identity":
         functools.partial(_generalized_kernel, f=lambda x: x, h=lambda x: 1)
 }
+
+_NON_CAUSAL_SUPPORT_TRANSFORM_MAP = {
+    "expplus": expplus,
+}
+
+_TRANSFORM_MAP = {
+    **_CAUSAL_SUPPORT_TRANSFORM_MAP,
+    **_NON_CAUSAL_SUPPORT_TRANSFORM_MAP
+}
+
 # pylint: enable=g-long-lambda
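As a reading aid (not part of the change), the split above is consumed as follows: the merged _TRANSFORM_MAP still serves ordinary feature lookups, while membership in _NON_CAUSAL_SUPPORT_TRANSFORM_MAP is what the new cache check in call() rejects. A minimal sketch, with a hypothetical helper name:

def _supports_streaming_cache(feature_transform):
  # Hypothetical helper mirroring the check added in KernelAttention.call():
  # every registered transform works without a cache; only transforms outside
  # the non-causal-only map may be used with a streaming cache.
  if feature_transform not in _TRANSFORM_MAP:
    raise ValueError("Unsupported feature_transform: %s" % feature_transform)
  return feature_transform not in _NON_CAUSAL_SUPPORT_TRANSFORM_MAP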
@@ -609,6 +633,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
                          feature_transform,
                          is_short_seq,
                          attention_mask=None,
+                         cache=None,
                          training=False,
                          numeric_stabler=_NUMERIC_STABLER):
     """Applies kernel attention with query, key, value tensors.
@@ -628,6 +653,8 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       attention_mask: a boolean mask of shape `[B, S]`, that prevents attending
         to masked positions. Note that the mask is only applied to the keys.
         User may want to mask the output if query contains pads.
+      cache: Cache to accumulate history in memory. Used at inference time
+        (streaming, decoding) for causal attention.
       training: Python boolean indicating whether the layer should behave in
         training mode (adding dropout) or in inference mode (doing nothing).
       numeric_stabler: A scalar value added to avoid divide by 0.
@@ -682,7 +709,8 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
            chunk_length=self.causal_chunk_length,
            window_length=self.causal_window_length,
            window_decay=self.causal_window_decay,
-           padding=self.causal_padding)
+           padding=self.causal_padding,
+           cache=cache)
       else:
         kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
         denominator = 1.0 / (
@@ -709,7 +737,8 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
           name="attention_output_softmax")
       self._dropout_softmax = tf.keras.layers.Dropout(rate=self._dropout)
 
-  def call(self, query, value, key=None, attention_mask=None, training=False):
+  def call(self, query, value, key=None, attention_mask=None, cache=None,
+           training=False):
     """Compute attention with kernel mechanism.
 
     Args:
@@ -720,12 +749,29 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       attention_mask: a boolean mask of shape `[B, S]`, that prevents attending
         to masked positions. Note that the mask is only applied to the keys.
        User may want to mask the output if query contains pads.
+      cache: Cache to accumulate history in memory. Used at inference time
+        (streaming, decoding) for causal attention.
       training: Python boolean indicating whether the layer should behave in
         training mode (adding dropout) or in inference mode (doing nothing).
 
     Returns:
       Multi-headed outputs of attention computation.
     """
+    if cache is not None:
+      if training:
+        raise ValueError(
+            "Cache is not supported when training is True.")
+      if not self.use_causal_windowed:
+        raise ValueError(
+            "Cache is not supported for non use_causal_windowed case.")
+      if self._begin_kernel:
+        raise ValueError(
+            "Cache is not supported when begin_kernel is set since the "
+            "behavior is too complicated.")
+      if self._feature_transform in _NON_CAUSAL_SUPPORT_TRANSFORM_MAP:
+        raise ValueError("Cache is not supported for feature_transform %s" %
+                         (self._feature_transform))
+
     if not self._built_from_signature:
       self._build_from_signature(query=query, value=value, key=key)
     if key is None:
@@ -761,7 +807,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     attention_output = self._compute_attention(query, key, value,
                                                self._feature_transform,
                                                self._is_short_seq,
-                                               attention_mask, training)
+                                               attention_mask,
+                                               cache,
+                                               training)
     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
     attention_output = self._dropout_layer(attention_output)
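A hedged sketch (not in the diff) of how a streaming cache can be zero-initialized for this layer; the shapes follow the einsums in the streaming branch above and the initialization used in the new test below. The helper name and arguments are hypothetical; feature_dim is num_random_features when random features are used, otherwise key_dim.

import tensorflow as tf

def make_streaming_cache(batch_size, num_heads, value_dim, feature_dim):
  # Hypothetical helper: zero running sums for causal streaming inference.
  return {
      # Accumulates einsum("BTHD,BTHO->BHOD", key_prime, value): [B, H, O, D].
      "kv": tf.zeros([batch_size, num_heads, value_dim, feature_dim]),
      # Accumulates reduce_sum(key_prime, axis=1); the size-1 axis broadcasts
      # over heads when added to a [B, H, D] sum.
      "k_sum": tf.zeros([batch_size, 1, feature_dim]),
  }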
@@ -30,6 +30,64 @@ _BEGIN_KERNEL = [0, 512]
 
 class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
 
+  # expplus is only designed for the bi-directional use case.
+  # exp can be numerically unstable.
+  @parameterized.parameters(itertools.product(
+      ["relu", "elu"], [1, 4], [0.9]))
+  def test_causal_windowed_attention_projection_streaming(
+      self, feature_transform, causal_chunk_length, causal_weight_decay):
+    num_heads = 12
+    key_dim = 64
+    seq_length = 16
+    num_chunks = seq_length // causal_chunk_length
+    causal_window_length = num_chunks
+    batch_size = 2
+    training = False
+    num_random_features = 0
+    test_layer = attention.KernelAttention(
+        num_heads=num_heads,
+        key_dim=key_dim,
+        feature_transform=feature_transform,
+        num_random_features=num_random_features,
+        redraw=False,
+        is_short_seq=False,
+        begin_kernel=False,
+        use_causal_windowed=True,
+        causal_chunk_length=causal_chunk_length,
+        causal_window_length=causal_window_length,
+        causal_window_decay=causal_weight_decay,
+        causal_padding=None,
+    )
+    query = tf.random.normal(
+        shape=(batch_size, seq_length, key_dim), seed=2)
+    value = query
+    encoder_inputs_mask = tf.ones((batch_size, seq_length), dtype=tf.int32)
+    masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
+    output = test_layer(
+        query=query,
+        value=value,
+        attention_mask=masks,
+        training=training)
+    kv_cache = tf.zeros(
+        (batch_size, num_heads, key_dim,
+         num_random_features if num_random_features > 0 else key_dim))
+    k_sum_cache = tf.zeros((batch_size, 1, key_dim))
+    stream_output = []
+    cache = {"kv": kv_cache, "k_sum": k_sum_cache}
+    for i in range(num_chunks):
+      stream_output.append(
+          test_layer(
+              query=query[:, i * causal_chunk_length:(i + 1) *
+                          causal_chunk_length, :],
+              value=value[:, i * causal_chunk_length:(i + 1) *
+                          causal_chunk_length, :],
+              attention_mask=masks[:, i * causal_chunk_length:(i + 1) *
+                                   causal_chunk_length],
+              cache=cache,
+              training=training))
+    stream_output = tf.concat(stream_output, axis=1)
+    self.assertAllClose(output, stream_output)
+
   @parameterized.parameters(
       itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
                         _IS_SHORT_SEQ, _BEGIN_KERNEL))
@@ -196,6 +254,5 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
             [2, 1, 2, 2, 2]),
         winsum)
 
-
 if __name__ == "__main__":
   tf.test.main()