Commit 4ad903b4 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 463764367
parent 93245b4f
@@ -41,7 +41,7 @@ class KernelMask(tf.keras.layers.Layer):
     return mask


-def pad_to_chunk_length(tensor, axis, chunk_length, pad="right"):
+def pad_to_chunk_length(tensor, axis, chunk_length, padding=None):
   """Pads a tensor so that shape[axis] is divisible by chunk_length.

   Args:
@@ -49,9 +49,11 @@ def pad_to_chunk_length(tensor, axis, chunk_length, pad="right"):
     axis: Axis to pad along.
     chunk_length: The output tensor will have shape[axis] divisible by
       chunk_length.
-    pad: Pad the input tensor across the axis from left if pad="left", right if
-      pad="right", or apply no padding if pad=None. In the latter case, the axis
-      dimension of the input tensor must be divisible by the chunk_length.
+    padding: Pad the input tensor across the axis from either left or
+      right if padding is set to "left" or "right"; applies no padding
+      if padding is set to None. In the latter case, the axis
+      dimension of the input tensor must be divisible by the
+      chunk_length.

   Returns:
     Padded tensor with shape[axis] divisible by chunk_length.
@@ -62,19 +64,23 @@ def pad_to_chunk_length(tensor, axis, chunk_length, pad="right"):
     axis += rank
   axis_length = shape[axis]
   pad_length = -axis_length % chunk_length
-  if pad == "right":
-    pad_width_2 = [[0, pad_length]]
-  elif pad == "left":
-    pad_width_2 = [[pad_length, 0]]
-  else:
+  if padding == "right":
+    axis_paddings = [[0, pad_length]]
+  elif padding == "left":
+    axis_paddings = [[pad_length, 0]]
+  elif padding is None:
     if pad_length != 0:
-      raise ValueError("When padding is not set, the axis dimension"
+      raise ValueError("When padding is None, the axis dimension"
                        "has to be divisible by the chunk_length.")
     return tensor
-  pad_width = tf.concat(
-      [tf.zeros([axis, 2], dtype=tf.int32), pad_width_2,
+  else:
+    raise ValueError("Illegal padding value; must be one of \"left\""
+                     "\"right\" or None.")
+  paddings = tf.concat(
+      [tf.zeros([axis, 2], dtype=tf.int32),
+       axis_paddings,
        tf.zeros([rank - axis - 1, 2], dtype=tf.int32)], axis=0)
-  return tf.pad(tensor, pad_width)
+  return tf.pad(tensor, paddings)


 def split_tensor_into_chunks(tensor, axis, chunk_length):
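For orientation, the padding rule that the renamed `padding` argument applies can be sketched in plain TensorFlow on a rank-1 tensor (the helper above generalizes this to an arbitrary axis; the values below are illustrative only):

```python
import tensorflow as tf

# Illustrative values, not taken from the change above.
tensor = tf.range(7)                    # axis length 7
chunk_length = 3
pad_length = -7 % chunk_length          # 2, so the padded length becomes 9
padded_right = tf.pad(tensor, [[0, pad_length]])  # padding="right"
padded_left = tf.pad(tensor, [[pad_length, 0]])   # padding="left"
print(padded_right.shape, padded_left.shape)      # (9,) (9,), divisible by 3
```

With the new default `padding=None`, the same inputs would instead raise a `ValueError`, since 7 is not divisible by 3.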
@@ -95,12 +101,12 @@ def split_tensor_into_chunks(tensor, axis, chunk_length):
   return tf.reshape(tensor, new_shape)


-def windowed_causal_performer_attention(query_matrix,
+def causal_windowed_performer_attention(query_matrix,
                                         key_matrix,
                                         value_matrix,
                                         chunk_length,
                                         window_length,
-                                        pad="right"):
+                                        padding=None):
   """Applies windowed causal kernel attention with query, key, value tensors.

   We partition the T-length input sequence into N chunks, each of chunk_length
@@ -113,19 +119,19 @@ def windowed_causal_performer_attention(query_matrix,
   Below is an example with T=9, chunk_length=3, window_length=1. 1 indicates
   attention is computed between the pair while 0 indicates attention is not
   computed between the pairs:
   111000000
   111000000
   111000000
   111111000
   111111000
   111111000
   000111111
   000111111
   000111111

   User can ensure sequence_length is divisible by chunk_length or use
-  pad="left"/"right" to pad the sequence length either at the top or bottom
-  respectively and make it divisible by chunk_length.
+  padding="left"/"right" to pad the sequence length either at the left
+  or right respectively and make it divisible by chunk_length.

   Args:
     query_matrix: Kernel query `Tensor` of shape `[B, T, N, dim]`.
@@ -133,20 +139,20 @@ def windowed_causal_performer_attention(query_matrix,
     value_matrix: Value `Tensor` of shape `[B, T, N, out_dim]`.
     chunk_length: Length of each chunk in tokens.
     window_length: Length of attention window in chunks.
-    pad: Pad the query, value and key input tensors across the T dimension from
-      left if pad="left", right if pad="right", or apply no padding if pad=None.
-      In the latter case, the T dimension of the input tensors must be divisible
-      by the chunk_length.
+    padding: Pad the query, value and key input tensors across the
+      axis from either left or right if padding is set to "left" or
+      "right"; apply no padding if padding is set to None. In the
+      latter case, the axis dimension of the query, value and key
+      input tensors must be divisible by the chunk_length.

   Returns:
     Window causal performer attention of shape `[B, T, N, out_dim]`.
   """
   old_shape = tf.shape(value_matrix)

-  query_matrix = pad_to_chunk_length(query_matrix, -3, chunk_length, pad)
-  key_matrix = pad_to_chunk_length(key_matrix, -3, chunk_length, pad)
-  value_matrix = pad_to_chunk_length(value_matrix, -3, chunk_length, pad)
+  query_matrix = pad_to_chunk_length(query_matrix, -3, chunk_length, padding)
+  key_matrix = pad_to_chunk_length(key_matrix, -3, chunk_length, padding)
+  value_matrix = pad_to_chunk_length(value_matrix, -3, chunk_length, padding)
   new_shape = tf.shape(value_matrix)

   chunked_query_matrix = split_tensor_into_chunks(
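The 9x9 pattern drawn in the docstring above (T=9, chunk_length=3, window_length=1) can be reproduced with a short NumPy sketch. This is purely an illustration of which query/key pairs attend under a straightforward reading of that figure; the function itself operates on chunked tensors rather than an explicit T-by-T mask:

```python
import numpy as np

T, chunk_length, window_length = 9, 3, 1
num_chunks = T // chunk_length
mask = np.zeros((T, T), dtype=int)
for i in range(num_chunks):
  # Queries in chunk i attend to chunks max(0, i - window_length) .. i.
  row_lo, row_hi = i * chunk_length, (i + 1) * chunk_length
  col_lo = max(0, i - window_length) * chunk_length
  mask[row_lo:row_hi, col_lo:row_hi] = 1
print(mask)  # rows reproduce 111000000, 111111000, 000111111 in blocks of 3
```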
@@ -446,16 +452,17 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
                begin_kernel=0,
                scale=None,
                scale_by_length=False,
-               use_windowed_causal=False,
-               chunk_length=1,
-               window_length=3,
+               use_causal_windowed=False,
+               causal_chunk_length=1,
+               causal_window_length=1,
+               causal_padding=None,
                **kwargs):
     r"""Constructor of KernelAttention.

     Args:
-      feature_transform: A non-linear transform of the keys and quries. Possible
-        transforms are "elu", "relu", "square", "exp", "expplus", "expmod",
-        "identity".
+      feature_transform: A non-linear transform of the keys and queries.
+        Possible transforms are "elu", "relu", "square", "exp", "expplus",
+        "expmod", "identity".
       num_random_features: Number of random features to be used for projection.
         if num_random_features <= 0, no production is used before transform.
       seed: The seed to begin drawing random features. Once the seed is set, the
@@ -475,11 +482,17 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
         the dot product based on key length. Set as log_512^(n) to stablize
         attention entropy against length. Refer to
         https://kexue.fm/archives/8823 for details.
-      use_windowed_causal: If true perform windowed causal attention. See
-        windowed_causal_performer_attention function docstring for more details.
-      chunk_length: Length of each chunk in tokens.
-      window_length: Length of attention window in chunks.
-      **kwargs: The same arguments `MultiHeadAttention` layer.
+      use_causal_windowed: If true perform windowed causal attention. See
+        causal_windowed_performer_attention function docstring for more details.
+      causal_chunk_length: Length of each chunk in tokens.
+      causal_window_length: Length of attention window in chunks.
+      causal_padding: Pad the query, value and key input tensors
+        across the axis from either left or right if padding is set to
+        "left" or "right"; apply no padding if padding is set to None.
+        In the latter case, the axis dimension of the query, value and
+        key input tensors must be divisible by the chunk_length.
+      **kwargs:
+        The same arguments `MultiHeadAttention` layer.
     """
     if (feature_transform not in _TRANSFORM_MAP and
         feature_transform != "expplus"):
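With the renamed constructor arguments, enabling the windowed causal path might look like the sketch below. The argument values and the import path are illustrative assumptions (the test file imports the module as `kernel_attention`); only the `use_causal_windowed` / `causal_*` names come from this change:

```python
# Sketch only: import path and argument values are assumptions for illustration.
from official.nlp.modeling.layers import kernel_attention

layer = kernel_attention.KernelAttention(
    num_heads=12,
    key_dim=64,
    feature_transform="exp",
    num_random_features=128,
    is_short_seq=False,   # must stay False: mutually exclusive with the windowed path
    use_causal_windowed=True,
    causal_chunk_length=8,
    causal_window_length=3,
    causal_padding="right")
```

Setting `causal_padding="right"` lets sequence lengths that are not multiples of `causal_chunk_length` through; with the default `causal_padding=None` the sequence length must already divide evenly.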
@@ -509,12 +522,13 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       self._projection_matrix = create_projection_matrix(
           self._num_random_features, self._key_dim,
           tf.constant([self._seed, self._seed + 1]))
-    self.use_windowed_causal = use_windowed_causal
-    self.chunk_length = chunk_length
-    self.window_length = window_length
-    if self.use_windowed_causal and self._is_short_seq:
+    self.use_causal_windowed = use_causal_windowed
+    self.causal_chunk_length = causal_chunk_length
+    self.causal_window_length = causal_window_length
+    self.causal_padding = causal_padding
+    if self.use_causal_windowed and self._is_short_seq:
       raise ValueError(
-          "use_windowed_causal and short_seq methods are mutually exclusive")
+          "use_causal_windowed and short_seq methods are mutually exclusive")

   def _compute_attention(self,
                          query,
@@ -590,9 +604,12 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       attention_scores = tf.einsum("BTNH,BSNH->BTSN", query_prime, key_prime)
       attention_scores = tf.nn.softmax(attention_scores, axis=2)
       attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
-    elif self.use_windowed_causal:
-      attention_output = windowed_causal_performer_attention(
-          query_prime, key_prime, value, self.chunk_length, self.window_length)
+    elif self.use_causal_windowed:
+      attention_output = causal_windowed_performer_attention(
+          query_prime, key_prime, value,
+          chunk_length=self.causal_chunk_length,
+          window_length=self.causal_window_length,
+          padding=self.causal_padding)
     else:
       kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
       denominator = 1.0 / (
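As a rough shape check for the renamed call above, here is a hedged sketch of invoking `causal_windowed_performer_attention` directly on random stand-ins for the kernel features, following the `[B, T, N, dim]` convention in its docstring. The import path is an assumption, and the expected output shape is taken from the docstring's Returns section rather than verified here:

```python
import tensorflow as tf
from official.nlp.modeling.layers import kernel_attention  # assumed path

B, T, N, dim, out_dim = 2, 10, 4, 16, 8
query_prime = tf.random.uniform([B, T, N, dim])  # stand-ins for kernel features
key_prime = tf.random.uniform([B, T, N, dim])
value = tf.random.normal([B, T, N, out_dim])

out = kernel_attention.causal_windowed_performer_attention(
    query_prime, key_prime, value,
    chunk_length=3,
    window_length=1,
    padding="right")   # T=10 is not divisible by 3, so padding is required
print(out.shape)       # expected (2, 10, 4, 8) per the docstring's Returns
```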
@@ -63,7 +63,7 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
   @parameterized.parameters(
       itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
                         [0]))
-  def test_windowed_causal_attention_projection(
+  def test_causal_windowed_attention_projection(
       self, feature_transform, num_random_features, training, redraw,
       begin_kernel):
     num_heads = 12
@@ -78,9 +78,9 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
         redraw=redraw,
         is_short_seq=False,
         begin_kernel=begin_kernel,
-        use_windowed_causal=True,
-        chunk_length=8,
-        window_length=3)
+        use_causal_windowed=True,
+        causal_chunk_length=8,
+        causal_window_length=3)
     query = tf.random.normal(
         shape=(batch_size, seq_length, key_dim))
     value = query