Commit d1fca260 authored by Avi Dubey, committed by A. Unique TensorFlower

windowed causal performer

PiperOrigin-RevId: 463429471
parent 1db7588c
@@ -41,6 +41,148 @@ class KernelMask(tf.keras.layers.Layer):
    return mask


def pad_to_chunk_length(tensor, axis, chunk_length, pad="right"):
  """Pads a tensor so that shape[axis] is divisible by chunk_length.

  Args:
    tensor: Input tensor to pad.
    axis: Axis to pad along.
    chunk_length: The output tensor will have shape[axis] divisible by
      chunk_length.
    pad: Pad the input tensor across the axis from the left if pad="left",
      from the right if pad="right", or apply no padding if pad=None. In the
      latter case, the axis dimension of the input tensor must be divisible
      by chunk_length.

  Returns:
    Padded tensor with shape[axis] divisible by chunk_length.
  """
  shape = tf.shape(tensor)
  rank = tf.rank(tensor)
  if axis < 0:
    axis += rank
  axis_length = shape[axis]
  pad_length = -axis_length % chunk_length
  if pad == "right":
    pad_width_2 = [[0, pad_length]]
  elif pad == "left":
    pad_width_2 = [[pad_length, 0]]
  else:
    if pad_length != 0:
      raise ValueError("When padding is not set, the axis dimension "
                       "has to be divisible by the chunk_length.")
    return tensor
  pad_width = tf.concat(
      [tf.zeros([axis, 2], dtype=tf.int32), pad_width_2,
       tf.zeros([rank - axis - 1, 2], dtype=tf.int32)], axis=0)
  return tf.pad(tensor, pad_width)
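
# Shape sketch for pad_to_chunk_length (a hedged example; the shapes below are
# assumed for illustration): -axis_length % chunk_length zeros are appended
# along the chosen axis.
#
#   x = tf.zeros([2, 10, 8])
#   padded = pad_to_chunk_length(x, axis=1, chunk_length=4)  # pad="right"
#   # tf.shape(padded) == [2, 12, 8]; with pad=None the same call would raise
#   # a ValueError because 10 is not divisible by 4.
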
def split_tensor_into_chunks(tensor, axis, chunk_length):
  """Reshapes a tensor along the given axis using chunk_length.

  Args:
    tensor: Input tensor.
    axis: Reshape the tensor along this axis.
    chunk_length: Split the axis into [axis / chunk_length, chunk_length].

  Returns:
    Reshaped tensor.
  """
  shape = tf.shape(tensor)
  num_chunks = shape[axis] // chunk_length
  new_shape = tf.concat(
      [shape[:axis], [num_chunks, chunk_length], shape[(axis + 1):]], axis=0)
  return tf.reshape(tensor, new_shape)
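
# Shape sketch for split_tensor_into_chunks (a hedged example with assumed
# shapes): the chosen axis must already be divisible by chunk_length, e.g.
# after pad_to_chunk_length above.
#
#   x = tf.zeros([2, 12, 8])
#   chunks = split_tensor_into_chunks(x, axis=1, chunk_length=4)
#   # tf.shape(chunks) == [2, 3, 4, 8]
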
def windowed_causal_performer_attention(query_matrix,
                                        key_matrix,
                                        value_matrix,
                                        chunk_length,
                                        window_length,
                                        pad="right"):
  """Applies windowed causal kernel attention to query, key, value tensors.

  We partition the T-length input sequence into chunks of chunk_length tokens
  each (thus T = num_chunks * chunk_length). Within each chunk, we apply the
  Performer's bidirectional (non-causal) implicit attention, and we model
  relationships between different chunks using the Performer's causal
  attention. We consider a windowed causal variant of the Performer, where
  each chunk attends only to the window of the window_length most recent
  chunks, itself included.

  Below is an example with T=9, chunk_length=3, window_length=2. A 1 indicates
  that attention is computed between the pair of positions, while a 0
  indicates that it is not:

  111000000
  111000000
  111000000
  111111000
  111111000
  111111000
  000111111
  000111111
  000111111

  The user can either ensure that the sequence length is divisible by
  chunk_length, or use pad="left"/"right" to pad the sequence at the start or
  the end respectively so that it becomes divisible by chunk_length.

  Args:
    query_matrix: Kernel query `Tensor` of shape `[B, T, N, dim]`.
    key_matrix: Kernel key `Tensor` of shape `[B, T, N, dim]`.
    value_matrix: Value `Tensor` of shape `[B, T, N, out_dim]`.
    chunk_length: Length of each chunk in tokens.
    window_length: Length of the attention window in chunks.
    pad: Pad the query, key and value tensors across the T dimension from the
      left if pad="left", from the right if pad="right", or apply no padding
      if pad=None. In the latter case, the T dimension of the input tensors
      must be divisible by chunk_length.

  Returns:
    Windowed causal performer attention of shape `[B, T, N, out_dim]`.
  """
  old_shape = tf.shape(value_matrix)
  query_matrix = pad_to_chunk_length(query_matrix, -3, chunk_length, pad)
  key_matrix = pad_to_chunk_length(key_matrix, -3, chunk_length, pad)
  value_matrix = pad_to_chunk_length(value_matrix, -3, chunk_length, pad)
  new_shape = tf.shape(value_matrix)
  chunked_query_matrix = split_tensor_into_chunks(
      query_matrix, -3,
      chunk_length)  # [B, T//chunk_length, chunk_length, N, dim]
  chunked_key_matrix = split_tensor_into_chunks(
      key_matrix, -3,
      chunk_length)  # [B, T//chunk_length, chunk_length, N, dim]
  chunked_value_matrix = split_tensor_into_chunks(
      value_matrix, -3,
      chunk_length)  # [B, T//chunk_length, chunk_length, N, out_dim]
  # Per-chunk sum of key-value outer products:
  # [B, num_chunks, num_heads, dim, out_dim].
  kp_v = tf.einsum("BNCHD,BNCHO->BNHDO", chunked_key_matrix,
                   chunked_value_matrix)
  # Windowed sum over the window_length most recent chunks (current included).
  kp_v_cumsum = tf.cumsum(kp_v, axis=-4)
  kp_v_winsum = kp_v_cumsum - tf.pad(
      kp_v_cumsum,
      [[0, 0], [window_length, 0], [0, 0], [0, 0], [0, 0]])[:, :-window_length]
  numerator = tf.einsum("BNCHD,BNHDO->BNCHO", chunked_query_matrix, kp_v_winsum)
  # Normalizer: the same windowed sum applied to the key features.
  k_sum = tf.reduce_sum(chunked_key_matrix, axis=-3)
  k_cumsum = tf.cumsum(k_sum, axis=-3)
  k_winsum = k_cumsum - tf.pad(k_cumsum, [[0, 0], [window_length, 0], [0, 0],
                                          [0, 0]])[:, :-window_length]
  denominator = tf.einsum("BNCHD,BNHD->BNCH", chunked_query_matrix, k_winsum)
  denominator = tf.expand_dims(denominator, -1) + _NUMERIC_STABLER
  attention = numerator / denominator
  # Merge chunks back to the padded sequence length, then slice back to the
  # original shape.
  attention = tf.reshape(attention, new_shape)
  start = tf.zeros([len(old_shape)], dtype=old_shape.dtype)
  attention = tf.slice(attention, start, old_shape)
  return attention
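
# A minimal usage sketch of the function above (hedged; the shapes and
# variable names are assumed for illustration only):
#
#   q = tf.random.uniform([2, 16, 4, 32])  # kernel query features [B, T, N, dim]
#   k = tf.random.uniform([2, 16, 4, 32])  # kernel key features   [B, T, N, dim]
#   v = tf.random.uniform([2, 16, 4, 64])  # values                [B, T, N, out_dim]
#   out = windowed_causal_performer_attention(
#       q, k, v, chunk_length=4, window_length=2)
#   # out has shape [2, 16, 4, 64]; each 4-token chunk attends to itself and
#   # to the previous chunk only.
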
def create_projection_matrix(m, d, seed=None):
  r"""Constructs the matrix of random projections.
@@ -304,6 +446,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
               begin_kernel=0,
               scale=None,
               scale_by_length=False,
               use_windowed_causal=False,
               chunk_length=1,
               window_length=3,
               **kwargs):
r"""Constructor of KernelAttention. r"""Constructor of KernelAttention.
@@ -330,9 +475,14 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
        the dot product based on key length. Set as log_512^(n) to stabilize
        attention entropy against length. Refer to
        https://kexue.fm/archives/8823 for details.
      use_windowed_causal: If true, perform windowed causal attention. See the
        windowed_causal_performer_attention function docstring for more
        details.
      chunk_length: Length of each chunk in tokens.
      window_length: Length of the attention window in chunks.
      **kwargs: The same arguments as the `MultiHeadAttention` layer.
    """
    if (feature_transform not in _TRANSFORM_MAP and
        feature_transform != "expplus"):
      raise ValueError("Unsupported feature_transform. The supported "
                       "feature_transform are %s. "
                       "Got '%s'." % (_TRANSFORM_MAP.keys(), feature_transform))
@@ -359,6 +509,12 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
    self._projection_matrix = create_projection_matrix(
        self._num_random_features, self._key_dim,
        tf.constant([self._seed, self._seed + 1]))
    self.use_windowed_causal = use_windowed_causal
    self.chunk_length = chunk_length
    self.window_length = window_length
    if self.use_windowed_causal and self._is_short_seq:
      raise ValueError(
          "use_windowed_causal and short_seq methods are mutually exclusive.")

  def _compute_attention(self,
                         query,
@@ -394,6 +550,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
      attention_output: Multi-headed outputs of attention computation.
    """
    projection_matrix = None
    if self._num_random_features > 0:
      if self._redraw and training:
        projection_matrix = create_projection_matrix(self._num_random_features,
@@ -433,6 +590,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
attention_scores = tf.einsum("BTNH,BSNH->BTSN", query_prime, key_prime) attention_scores = tf.einsum("BTNH,BSNH->BTSN", query_prime, key_prime)
attention_scores = tf.nn.softmax(attention_scores, axis=2) attention_scores = tf.nn.softmax(attention_scores, axis=2)
attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value) attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
elif self.use_windowed_causal:
attention_output = windowed_causal_performer_attention(
query_prime, key_prime, value, self.chunk_length, self.window_length)
else: else:
kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value) kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
denominator = 1.0 / ( denominator = 1.0 / (
...
@@ -60,6 +60,39 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
        training=training)
    self.assertEqual(output.shape, [batch_size, seq_length, key_dim])

  @parameterized.parameters(
      itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
                        [0]))
  def test_windowed_causal_attention_projection(
      self, feature_transform, num_random_features, training, redraw,
      begin_kernel):
    num_heads = 12
    key_dim = 64
    seq_length = 1024
    batch_size = 2
    test_layer = attention.KernelAttention(
        num_heads=num_heads,
        key_dim=key_dim,
        feature_transform=feature_transform,
        num_random_features=num_random_features,
        redraw=redraw,
        is_short_seq=False,
        begin_kernel=begin_kernel,
        use_windowed_causal=True,
        chunk_length=8,
        window_length=3)
    query = tf.random.normal(
        shape=(batch_size, seq_length, key_dim))
    value = query
    encoder_inputs_mask = tf.zeros((batch_size, seq_length), dtype=tf.int32)
    masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
    output = test_layer(
        query=query,
        value=value,
        attention_mask=masks,
        training=training)
    self.assertEqual(output.shape, [batch_size, seq_length, key_dim])

  @parameterized.parameters(itertools.product(
      _FEATURE_TRANSFORM, [0], _TRAINING, [False],
      _IS_SHORT_SEQ, _BEGIN_KERNEL))
...