Commit 3eed3e03 authored by Krzysztof Choromanski's avatar Krzysztof Choromanski Committed by A. Unique TensorFlower
Browse files

Improving integration of the FAVOR++ mechanism with the test of the Performer's code.

PiperOrigin-RevId: 466056003
parent cb7ae42f
...@@ -49,11 +49,10 @@ def pad_to_chunk_length(tensor, axis, chunk_length, padding=None): ...@@ -49,11 +49,10 @@ def pad_to_chunk_length(tensor, axis, chunk_length, padding=None):
axis: Axis to pad along. axis: Axis to pad along.
chunk_length: The output tensor will have shape[axis] divisible by chunk_length: The output tensor will have shape[axis] divisible by
chunk_length. chunk_length.
padding: Pad the input tensor across the axis from either left or padding: Pad the input tensor across the axis from either left or right if
right if padding is set to "left" or "right"; applies no padding padding is set to "left" or "right"; applies no padding if padding is set
if padding is set to None. In the latter case, the axis to None. In the latter case, the axis dimension of the input tensor must
dimension of the input tensor must be divisible by the be divisible by the chunk_length.
chunk_length.
Returns: Returns:
Padded tensor with shape[axis] divisible by chunk_length. Padded tensor with shape[axis] divisible by chunk_length.
...@@ -73,10 +72,11 @@ def pad_to_chunk_length(tensor, axis, chunk_length, padding=None): ...@@ -73,10 +72,11 @@ def pad_to_chunk_length(tensor, axis, chunk_length, padding=None):
else: else:
raise ValueError( raise ValueError(
"Illegal padding value; must be one of \"left\", \"right\" or None.") "Illegal padding value; must be one of \"left\", \"right\" or None.")
paddings = tf.concat( paddings = tf.concat([
[tf.zeros([axis, 2], dtype=tf.int32), tf.zeros([axis, 2], dtype=tf.int32), axis_paddings,
axis_paddings, tf.zeros([rank - axis - 1, 2], dtype=tf.int32)
tf.zeros([rank - axis - 1, 2], dtype=tf.int32)], axis=0) ],
axis=0)
return tf.pad(tensor, paddings) return tf.pad(tensor, paddings)
...@@ -94,7 +94,7 @@ def split_tensor_into_chunks(tensor, axis, chunk_length): ...@@ -94,7 +94,7 @@ def split_tensor_into_chunks(tensor, axis, chunk_length):
shape = tf.shape(tensor) shape = tf.shape(tensor)
num_chunks = shape[axis] // chunk_length num_chunks = shape[axis] // chunk_length
new_shape = tf.concat( new_shape = tf.concat(
[shape[:axis], [num_chunks, chunk_length], shape[(axis+1):]], axis=0) [shape[:axis], [num_chunks, chunk_length], shape[(axis + 1):]], axis=0)
return tf.reshape(tensor, new_shape) return tf.reshape(tensor, new_shape)
...@@ -128,8 +128,7 @@ def weighted_window_sum(tensor, window_length, window_weights): ...@@ -128,8 +128,7 @@ def weighted_window_sum(tensor, window_length, window_weights):
Args: Args:
tensor: Tensor of shape `[B, T', C', H, dim]`. tensor: Tensor of shape `[B, T', C', H, dim]`.
window_length: The length of the window. window_length: The length of the window.
window_weights: Tensor of shape [window_length] containing window window_weights: Tensor of shape [window_length] containing window weights.
weights.
Returns: Returns:
A tensor of shape [B, T', C', H, dim] containing sums over the A tensor of shape [B, T', C', H, dim] containing sums over the
...@@ -196,14 +195,13 @@ def causal_windowed_performer_attention(query_matrix, ...@@ -196,14 +195,13 @@ def causal_windowed_performer_attention(query_matrix,
value_matrix: Value `Tensor` of shape `[B, T, H, out_dim]`. value_matrix: Value `Tensor` of shape `[B, T, H, out_dim]`.
chunk_length: Length of each chunk in tokens. chunk_length: Length of each chunk in tokens.
window_length: Length of attention window in chunks. window_length: Length of attention window in chunks.
window_decay: Float window decay factor or `None`. If set, window_decay: Float window decay factor or `None`. If set, exponentially
exponentially decay past attention window values by this factor decay past attention window values by this factor before summation.
before summation. padding: Pad the query, value and key input tensors across the axis from
padding: Pad the query, value and key input tensors across the either left or right if padding is set to "left" or "right"; apply no
axis from either left or right if padding is set to "left" or padding if padding is set to None. In the latter case, the axis dimension
"right"; apply no padding if padding is set to None. In the of the query, value and key input tensors must be divisible by the
latter case, the axis dimension of the query, value and key chunk_length.
input tensors must be divisible by the chunk_length.
Returns: Returns:
Window causal performer attention of shape `[B, T, H, out_dim]`. Window causal performer attention of shape `[B, T, H, out_dim]`.
...@@ -302,11 +300,13 @@ def create_projection_matrix(m, d, seed=None): ...@@ -302,11 +300,13 @@ def create_projection_matrix(m, d, seed=None):
return tf.linalg.matmul(tf.linalg.diag(multiplier), final_matrix) return tf.linalg.matmul(tf.linalg.diag(multiplier), final_matrix)
def _generalized_kernel(x, projection_matrix, f, h): def _generalized_kernel(x, y, is_query, projection_matrix, f, h):
"""Generalized kernel in RETHINKING ATTENTION WITH PERFORMERS. """Generalized kernel in RETHINKING ATTENTION WITH PERFORMERS.
Args: Args:
x: The feature being transformed with shape [B, T, N, H]. x: The feature being transformed with shape [B, T, N, H].
y: The extra stats-tensor of shape [B, T, N, H].
is_query: True if x is a query-tensor.
projection_matrix: The matrix with shape [M, H] that we project x to, where projection_matrix: The matrix with shape [M, H] that we project x to, where
M is the number of projections. M is the number of projections.
f: A non-linear function applied on x or projected x. f: A non-linear function applied on x or projected x.
...@@ -316,7 +316,8 @@ def _generalized_kernel(x, projection_matrix, f, h): ...@@ -316,7 +316,8 @@ def _generalized_kernel(x, projection_matrix, f, h):
Returns: Returns:
Transformed feature. Transformed feature.
""" """
del y
del is_query
if projection_matrix is None: if projection_matrix is None:
return h(x) * f(x) return h(x) * f(x)
else: else:
...@@ -475,6 +476,8 @@ _TRANSFORM_MAP = { ...@@ -475,6 +476,8 @@ _TRANSFORM_MAP = {
h=lambda x: tf.math.exp(-0.5 * tf.math.sqrt( h=lambda x: tf.math.exp(-0.5 * tf.math.sqrt(
tf.cast(tf.shape(x)[-1], tf.float32))), tf.cast(tf.shape(x)[-1], tf.float32))),
), ),
"expplus":
expplus,
"identity": "identity":
functools.partial(_generalized_kernel, f=lambda x: x, h=lambda x: 1) functools.partial(_generalized_kernel, f=lambda x: x, h=lambda x: 1)
} }
...@@ -554,18 +557,16 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention): ...@@ -554,18 +557,16 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
causal_chunk_length: Length of each chunk in tokens. causal_chunk_length: Length of each chunk in tokens.
causal_window_length: Length of attention window in chunks. causal_window_length: Length of attention window in chunks.
causal_window_decay: Float window decay factor or `None`. If set, causal_window_decay: Float window decay factor or `None`. If set,
exponentially decay past attention window values by this exponentially decay past attention window values by this factor before
factor before summation. summation.
causal_padding: Pad the query, value and key input tensors causal_padding: Pad the query, value and key input tensors across the axis
across the axis from either left or right if padding is set to from either left or right if padding is set to "left" or "right"; apply
"left" or "right"; apply no padding if padding is set to None. no padding if padding is set to None. In the latter case, the axis
In the latter case, the axis dimension of the query, value and dimension of the query, value and key input tensors must be divisible by
key input tensors must be divisible by the chunk_length. the chunk_length.
**kwargs: **kwargs: The same arguments as the `MultiHeadAttention` layer.
The same arguments as the `MultiHeadAttention` layer.
""" """
if (feature_transform not in _TRANSFORM_MAP and if feature_transform not in _TRANSFORM_MAP:
feature_transform != "expplus"):
raise ValueError("Unsupported feature_transform. The supported " raise ValueError("Unsupported feature_transform. The supported "
"feature_transform are %s. " "feature_transform are %s. "
"Got '%s'." % (_TRANSFORM_MAP.keys(), feature_transform)) "Got '%s'." % (_TRANSFORM_MAP.keys(), feature_transform))
...@@ -661,12 +662,10 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention): ...@@ -661,12 +662,10 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
key *= tf.math.sqrt(scale) key *= tf.math.sqrt(scale)
query *= tf.math.sqrt(scale) query *= tf.math.sqrt(scale)
if feature_transform != "expplus": key_prime = _TRANSFORM_MAP[feature_transform](key, query, False,
key_prime = _TRANSFORM_MAP[feature_transform](key, projection_matrix) projection_matrix)
query_prime = _TRANSFORM_MAP[feature_transform](query, projection_matrix) query_prime = _TRANSFORM_MAP[feature_transform](query, key, True,
else: projection_matrix)
key_prime = expplus(key, query, False, projection_matrix)
query_prime = expplus(query, key, True, projection_matrix)
if attention_mask is not None: if attention_mask is not None:
key_prime = tf.einsum("BSNH,BS->BSNH", key_prime, attention_mask) key_prime = tf.einsum("BSNH,BS->BSNH", key_prime, attention_mask)
...@@ -677,7 +676,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention): ...@@ -677,7 +676,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value) attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
elif self.use_causal_windowed: elif self.use_causal_windowed:
attention_output = causal_windowed_performer_attention( attention_output = causal_windowed_performer_attention(
query_prime, key_prime, value, query_prime,
key_prime,
value,
chunk_length=self.causal_chunk_length, chunk_length=self.causal_chunk_length,
window_length=self.causal_window_length, window_length=self.causal_window_length,
window_decay=self.causal_window_decay, window_decay=self.causal_window_decay,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment