Commit 3eed3e03 authored by Krzysztof Choromanski, committed by A. Unique TensorFlower

Improving integration of the FAVOR++ mechanism with the test of the Performer's code.

PiperOrigin-RevId: 466056003
parent cb7ae42f
@@ -49,11 +49,10 @@ def pad_to_chunk_length(tensor, axis, chunk_length, padding=None):
axis: Axis to pad along.
chunk_length: The output tensor will have shape[axis] divisible by
chunk_length.
padding: Pad the input tensor across the axis from either left or
right if padding is set to "left" or "right"; applies no padding
if padding is set to None. In the latter case, the axis
dimension of the input tensor must be divisible by the
chunk_length.
padding: Pad the input tensor across the axis from either left or right if
padding is set to "left" or "right"; applies no padding if padding is set
to None. In the latter case, the axis dimension of the input tensor must
be divisible by the chunk_length.
Returns:
Padded tensor with shape[axis] divisible by chunk_length.
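A minimal standalone sketch of the padding behavior described in this docstring, assuming TensorFlow 2.x and statically known shapes; the helper name `pad_axis_to_chunk_length` is illustrative, not the library function itself.

```python
import tensorflow as tf


def pad_axis_to_chunk_length(tensor, axis, chunk_length, padding="right"):
  """Pads `tensor` along `axis` so that dimension is divisible by chunk_length."""
  length = tensor.shape[axis]            # static shape, for simplicity
  pad_amount = -length % chunk_length    # 0 when already divisible
  if padding == "right":
    axis_padding = [0, pad_amount]
  elif padding == "left":
    axis_padding = [pad_amount, 0]
  elif padding is None:
    if pad_amount:
      raise ValueError("axis dimension must already be divisible by chunk_length.")
    return tensor
  else:
    raise ValueError('padding must be "left", "right" or None.')
  rank = tensor.shape.rank
  paddings = [[0, 0]] * axis + [axis_padding] + [[0, 0]] * (rank - axis - 1)
  return tf.pad(tensor, paddings)


x = tf.ones([2, 7, 4])                          # sequence length 7
print(pad_axis_to_chunk_length(x, 1, 3).shape)  # (2, 9, 4): padded up to a multiple of 3
```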
@@ -73,10 +72,11 @@ def pad_to_chunk_length(tensor, axis, chunk_length, padding=None):
else:
raise ValueError(
"Illegal padding value; must be one of \"left\", \"right\" or None.")
paddings = tf.concat(
[tf.zeros([axis, 2], dtype=tf.int32),
axis_paddings,
tf.zeros([rank - axis - 1, 2], dtype=tf.int32)], axis=0)
paddings = tf.concat([
tf.zeros([axis, 2], dtype=tf.int32), axis_paddings,
tf.zeros([rank - axis - 1, 2], dtype=tf.int32)
],
axis=0)
return tf.pad(tensor, paddings)
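Purely illustrative (assuming TensorFlow 2.x): what the concatenated `paddings` matrix built above evaluates to for a rank-3 tensor padded on the right along axis 1.

```python
import tensorflow as tf

axis, rank = 1, 3
axis_paddings = tf.constant([[0, 2]], dtype=tf.int32)   # 2 elements of right padding
paddings = tf.concat([
    tf.zeros([axis, 2], dtype=tf.int32), axis_paddings,
    tf.zeros([rank - axis - 1, 2], dtype=tf.int32)
], axis=0)
print(paddings.numpy())                              # [[0 0] [0 2] [0 0]]: only axis 1 is padded
print(tf.pad(tf.ones([2, 7, 4]), paddings).shape)    # (2, 9, 4)
```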
@@ -94,7 +94,7 @@ def split_tensor_into_chunks(tensor, axis, chunk_length):
shape = tf.shape(tensor)
num_chunks = shape[axis] // chunk_length
new_shape = tf.concat(
[shape[:axis], [num_chunks, chunk_length], shape[(axis+1):]], axis=0)
[shape[:axis], [num_chunks, chunk_length], shape[(axis + 1):]], axis=0)
return tf.reshape(tensor, new_shape)
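A quick, self-contained check of the chunking reshape above (the diff only changes spacing): with `T` divisible by `chunk_length`, a `[B, T, H, dim]` tensor becomes `[B, num_chunks, chunk_length, H, dim]`.

```python
import tensorflow as tf

x = tf.random.normal([2, 6, 3, 4])      # [B, T, H, dim] with T = 6
axis, chunk_length = 1, 2
shape = tf.shape(x)
num_chunks = shape[axis] // chunk_length
new_shape = tf.concat(
    [shape[:axis], [num_chunks, chunk_length], shape[(axis + 1):]], axis=0)
print(tf.reshape(x, new_shape).shape)   # (2, 3, 2, 3, 4)
```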
@@ -128,8 +128,7 @@ def weighted_window_sum(tensor, window_length, window_weights):
Args:
tensor: Tensor of shape `[B, T', C', H, dim]`.
window_length: The length of the window.
window_weights: Tensor of shape [window_length] containing window
weights.
window_weights: Tensor of shape [window_length] containing window weights.
Returns:
A tensor of shape [B, T', C', H, dim] containing sums over the
@@ -196,14 +195,13 @@ def causal_windowed_performer_attention(query_matrix,
value_matrix: Value `Tensor` of shape `[B, T, H, out_dim]`.
chunk_length: Length of each chunk in tokens.
window_length: Length of attention window in chunks.
window_decay: Float window decay factor or `None`. If set,
exponentially decay past attention window values by this factor
before summation.
padding: Pad the query, value and key input tensors across the
axis from either left or right if padding is set to "left" or
"right"; apply no padding if padding is set to None. In the
latter case, the axis dimension of the query, value and key
input tensors must be divisible by the chunk_length.
window_decay: Float window decay factor or `None`. If set, exponentially
decay past attention window values by this factor before summation.
padding: Pad the query, value and key input tensors across the axis from
either left or right if padding is set to "left" or "right"; apply no
padding if padding is set to None. In the latter case, the axis dimension
of the query, value and key input tensors must be divisible by the
chunk_length.
Returns:
Window causal performer attention of shape `[B, T, H, out_dim]`.
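A hedged usage sketch based on the Args above. The import path is an assumption (the function lives in the Model Garden's kernel attention module), and `query_prime`/`key_prime` stand for tensors that have already been passed through a random-feature transform.

```python
import tensorflow as tf
from official.nlp.modeling.layers import kernel_attention as ka  # assumed path

B, T, H, dim, out_dim = 2, 12, 4, 8, 16
query_prime = tf.random.uniform([B, T, H, dim])
key_prime = tf.random.uniform([B, T, H, dim])
value = tf.random.uniform([B, T, H, out_dim])

out = ka.causal_windowed_performer_attention(
    query_prime, key_prime, value,
    chunk_length=3,      # 12 tokens -> 4 chunks of 3 tokens
    window_length=2,     # each chunk attends to itself plus 1 previous chunk
    window_decay=0.9,    # exponentially down-weight past chunks
    padding="right")     # irrelevant here since 12 is already divisible by 3
print(out.shape)         # expected (2, 12, 4, 16) per the Returns section
```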
@@ -302,11 +300,13 @@ def create_projection_matrix(m, d, seed=None):
return tf.linalg.matmul(tf.linalg.diag(multiplier), final_matrix)
def _generalized_kernel(x, projection_matrix, f, h):
def _generalized_kernel(x, y, is_query, projection_matrix, f, h):
"""Generalized kernel in RETHINKING ATTENTION WITH PERFORMERS.
Args:
x: The feature being transformed with shape [B, T, N, H].
y: The extra stats-tensor of shape [B, T, N, H].
is_query: True if x is a query-tensor.
projection_matrix: The matrix with shape [M, H] that we project x to, where
M is the number of projections.
f: A non-linear function applied on x or projected x.
@@ -316,7 +316,8 @@ def _generalized_kernel(x, projection_matrix, f, h):
Returns:
Transformed feature.
"""
del y
del is_query
if projection_matrix is None:
return h(x) * f(x)
else:
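The widened signature exists because FAVOR++ (`expplus`) computes its features from the companion stats tensor `y` and the `is_query` flag, while the older transforms need only `x`; the `del y` / `del is_query` lines let those transforms accept the uniform argument list and ignore the extras. Below is a toy relu transform using the same calling convention (a sketch, not the library code).

```python
import tensorflow as tf


def relu_feature_map(x, y, is_query, projection_matrix):
  """Toy transform with the uniform (x, y, is_query, projection) signature."""
  del y, is_query                      # unused by this simple transform
  if projection_matrix is None:
    return tf.nn.relu(x)
  # Project the head dimension H onto M random features, then apply relu.
  return tf.nn.relu(tf.einsum("BTNH,MH->BTNM", x, projection_matrix))


x = tf.random.normal([2, 5, 4, 8])     # [B, T, N, H]
proj = tf.random.normal([16, 8])       # [M, H]
print(relu_feature_map(x, x, True, proj).shape)    # (2, 5, 4, 16)
```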
@@ -475,6 +476,8 @@ _TRANSFORM_MAP = {
h=lambda x: tf.math.exp(-0.5 * tf.math.sqrt(
tf.cast(tf.shape(x)[-1], tf.float32))),
),
"expplus":
expplus,
"identity":
functools.partial(_generalized_kernel, f=lambda x: x, h=lambda x: 1)
}
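A minimal sketch of the dispatch-table pattern shown above, under the assumption that every registered transform now shares the `(x, y, is_query, projection_matrix)` calling convention: fixed-parameter transforms are bound with `functools.partial`, while `expplus` can be registered directly because it already takes all four arguments. Names here are illustrative.

```python
import functools
import tensorflow as tf


def generalized_kernel(x, y, is_query, projection_matrix, f, h):
  del y, is_query                     # not needed by these simple transforms
  if projection_matrix is None:
    return h(x) * f(x)
  return h(x) * f(tf.einsum("BTNH,MH->BTNM", x, projection_matrix))


TRANSFORM_MAP = {
    "relu": functools.partial(generalized_kernel, f=tf.nn.relu, h=lambda x: 1),
    "identity": functools.partial(
        generalized_kernel, f=lambda x: x, h=lambda x: 1),
    # "expplus": expplus,   # FAVOR++ would be registered as-is, no partial.
}

x = tf.random.normal([2, 5, 4, 8])
proj = tf.random.normal([16, 8])
out = TRANSFORM_MAP["relu"](x, x, is_query=True, projection_matrix=proj)
print(out.shape)                      # (2, 5, 4, 16)
```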
@@ -554,18 +557,16 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
causal_chunk_length: Length of each chunk in tokens.
causal_window_length: Length of attention window in chunks.
causal_window_decay: Float window decay factor or `None`. If set,
exponentially decay past attention window values by this
factor before summation.
causal_padding: Pad the query, value and key input tensors
across the axis from either left or right if padding is set to
"left" or "right"; apply no padding if padding is set to None.
In the latter case, the axis dimension of the query, value and
key input tensors must be divisible by the chunk_length.
**kwargs:
The same arguments `MultiHeadAttention` layer.
exponentially decay past attention window values by this factor before
summation.
causal_padding: Pad the query, value and key input tensors across the axis
from either left or right if padding is set to "left" or "right"; apply
no padding if padding is set to None. In the latter case, the axis
dimension of the query, value and key input tensors must be divisible by
the chunk_length.
**kwargs: The same arguments as the `MultiHeadAttention` layer.
"""
if (feature_transform not in _TRANSFORM_MAP and
feature_transform != "expplus"):
if feature_transform not in _TRANSFORM_MAP:
raise ValueError("Unsupported feature_transform. The supported "
"feature_transform are %s. "
"Got '%s'." % (_TRANSFORM_MAP.keys(), feature_transform))
@@ -661,12 +662,10 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
key *= tf.math.sqrt(scale)
query *= tf.math.sqrt(scale)
if feature_transform != "expplus":
key_prime = _TRANSFORM_MAP[feature_transform](key, projection_matrix)
query_prime = _TRANSFORM_MAP[feature_transform](query, projection_matrix)
else:
key_prime = expplus(key, query, False, projection_matrix)
query_prime = expplus(query, key, True, projection_matrix)
key_prime = _TRANSFORM_MAP[feature_transform](key, query, False,
projection_matrix)
query_prime = _TRANSFORM_MAP[feature_transform](query, key, True,
projection_matrix)
if attention_mask is not None:
key_prime = tf.einsum("BSNH,BS->BSNH", key_prime, attention_mask)
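A quick check of the masking einsum directly above (assuming TensorFlow 2.x): multiplying the transformed keys by a `[B, S]` attention mask zeroes out padded positions across every head and feature dimension before they enter the kernel sums.

```python
import tensorflow as tf

key_prime = tf.ones([1, 4, 2, 3])                   # [B, S, N, H]
attention_mask = tf.constant([[1., 1., 0., 0.]])    # last two tokens are padding
masked = tf.einsum("BSNH,BS->BSNH", key_prime, attention_mask)
print(masked[0, :, 0, 0].numpy())                   # [1. 1. 0. 0.]
```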
@@ -677,7 +676,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
elif self.use_causal_windowed:
attention_output = causal_windowed_performer_attention(
query_prime, key_prime, value,
query_prime,
key_prime,
value,
chunk_length=self.causal_chunk_length,
window_length=self.causal_window_length,
window_decay=self.causal_window_decay,
......