Commit a9a3bac9 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 464067452
parent 8843bb24
@@ -58,6 +58,8 @@ def pad_to_chunk_length(tensor, axis, chunk_length, padding=None):
Returns:
Padded tensor with shape[axis] divisible by chunk_length.
"""
if padding is None:
return tensor
shape = tf.shape(tensor)
rank = tf.rank(tensor)
if axis < 0:
@@ -68,14 +70,9 @@ def pad_to_chunk_length(tensor, axis, chunk_length, padding=None):
axis_paddings = [[0, pad_length]]
elif padding == "left":
axis_paddings = [[pad_length, 0]]
elif padding is None:
if pad_length != 0:
raise ValueError("When padding is None, the axis dimension"
"has to be divisible by the chunk_length.")
return tensor
else:
raise ValueError("Illegal padding value; must be one of \"left\""
"\"right\" or None.")
raise ValueError(
"Illegal padding value; must be one of \"left\", \"right\" or None.")
paddings = tf.concat(
[tf.zeros([axis, 2], dtype=tf.int32),
axis_paddings,
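For context, here is a minimal usage sketch of the padding behavior this hunk adjusts. The tensor shapes are illustrative only, and the direct tf.pad calls stand in for what pad_to_chunk_length builds from pad_length:

import tensorflow as tf

# Illustrative shapes: a [batch, seq_len, dim] tensor whose seq_len (10) is
# not divisible by chunk_length (4), so 2 extra positions are needed.
x = tf.ones([2, 10, 16])
chunk_length = 4
pad_length = -x.shape[1] % chunk_length  # == 2

# "right" padding appends zeros after the sequence; "left" prepends them.
right_padded = tf.pad(x, [[0, 0], [0, pad_length], [0, 0]])
left_padded = tf.pad(x, [[0, 0], [pad_length, 0], [0, 0]])

assert right_padded.shape[1] % chunk_length == 0  # seq_len is now 12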
@@ -109,16 +106,18 @@ def causal_windowed_performer_attention(query_matrix,
padding=None):
"""Applies windowed causal kernel attention with query, key, value tensors.
We partition the T-length input sequence into N chunks, each of chunk_length
tokens (thus: T = N * chunk_length). Within each chunk, we apply bidirectional
(non-causal) Performers’ implicit attention and we model relationships between
different chunks using Performers’ causal attention. We consider windowed
causal variant of performer, where the current chunk attends only to the
window of window_length of the most recent chunks.
We partition the T-length input sequence into N chunks, each of
chunk_length tokens (thus: T = N * chunk_length). Within each chunk,
we apply bidirectional (non-causal) Performers’ implicit attention
and we model relationships between different chunks using
Performers’ causal attention. We consider windowed causal variant of
performer, where the current chunk attends only to the window of
window_length of the most recent chunks.
Below is an example with T=9, chunk_length=3, window_length=2. In
this example 1 indicates attention is computed between the pair
while 0 indicates attention is not computed between the pairs:
Below is an example with T=9, chunk_length=3, window_length=1. 1 indicates
attention is computed between the pair while 0 indicates attention is not
computed between the pairs:
111000000
111000000
111000000
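The block pattern sketched above can be reproduced with a short mask-building snippet. This is only an illustration of the chunked windowing idea, not the Performer kernel computation itself, and it assumes window_length counts the current chunk as one of the most recent chunks:

import numpy as np

def windowed_chunk_mask(seq_len, chunk_length, window_length):
  # 1 where token i may attend to token j: j's chunk is i's own chunk or one
  # of the window_length - 1 chunks immediately preceding it.
  chunk_ids = np.arange(seq_len) // chunk_length
  diff = chunk_ids[:, None] - chunk_ids[None, :]
  return ((diff >= 0) & (diff < window_length)).astype(np.int32)

# With seq_len=9, chunk_length=3, window_length=2 the first three rows are
# 111000000, matching the example above.
print(windowed_chunk_mask(seq_len=9, chunk_length=3, window_length=2))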
@@ -454,7 +453,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
scale_by_length=False,
use_causal_windowed=False,
causal_chunk_length=1,
causal_window_length=1,
causal_window_length=3,
causal_padding=None,
**kwargs):
r"""Constructor of KernelAttention.
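A hedged usage sketch of the layer with the causal windowed options involved in this change: the argument names come from the diff and the test below, while the specific values are illustrative and the remaining constructor defaults are assumed rather than confirmed.

import tensorflow as tf
from official.nlp.modeling.layers import kernel_attention as attention

layer = attention.KernelAttention(
    num_heads=8,
    key_dim=64,
    feature_transform="exp",
    num_random_features=128,
    use_causal_windowed=True,
    causal_chunk_length=8,
    causal_window_length=3,   # the new default per this change
    causal_padding="right")   # pad seq_length up to a multiple of chunk_length

# Illustrative shapes; 128 tokens split cleanly into 16 chunks of 8.
x = tf.random.normal([2, 128, 64])
y = layer(query=x, value=x, training=False)
print(y.shape)  # (2, 128, 64)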
@@ -21,7 +21,7 @@ import tensorflow as tf
from official.nlp.modeling.layers import kernel_attention as attention
_FEATURE_TRANSFORM = ['relu', 'elu', 'exp', 'expplus']
_FEATURE_TRANSFORM = ["relu", "elu", "exp", "expplus"]
_REDRAW = [True, False]
_TRAINING = [True, False]
_IS_SHORT_SEQ = [True, False]
@@ -62,10 +62,10 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(
itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
[0]))
[0], [None, "left", "right"]))
def test_causal_windowed_attention_projection(
self, feature_transform, num_random_features, training, redraw,
begin_kernel):
begin_kernel, causal_padding):
num_heads = 12
key_dim = 64
seq_length = 1024
@@ -80,7 +80,8 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
begin_kernel=begin_kernel,
use_causal_windowed=True,
causal_chunk_length=8,
causal_window_length=3)
causal_window_length=3,
causal_padding=causal_padding)
query = tf.random.normal(
shape=(batch_size, seq_length, key_dim))
value = query
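As an aside on the parameterization pattern used in this test, the decorator's itertools.product call expands into one generated test per tuple. A quick sketch of the grid it produces, with values copied from the test module above:

import itertools

_FEATURE_TRANSFORM = ["relu", "elu", "exp", "expplus"]
_TRAINING = [True, False]

# Mirrors the decorator on test_causal_windowed_attention_projection.
grid = list(itertools.product(
    _FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
    [0], [None, "left", "right"]))

print(len(grid))  # 4 * 1 * 2 * 2 * 1 * 3 = 48 parameter combinations
print(grid[0])    # ('relu', 127, True, True, 0, None)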
@@ -150,14 +151,14 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
self.assertNotAllClose(output_scale_by_length, output_no_scale_by_length)
def test_unsupported_feature_transform(self):
with self.assertRaisesRegex(ValueError, 'Unsupported feature_transform.*'):
_ = attention.KernelAttention(feature_transform='test')
with self.assertRaisesRegex(ValueError, "Unsupported feature_transform.*"):
_ = attention.KernelAttention(feature_transform="test")
def test_redraw_true_no_projection(self):
with self.assertRaisesRegex(
ValueError, 'There is nothing to redraw when num_random_features.*'):
ValueError, "There is nothing to redraw when num_random_features.*"):
_ = attention.KernelAttention(
num_heads=2, key_dim=64, feature_transform='elu',
num_heads=2, key_dim=64, feature_transform="elu",
num_random_features=0, redraw=True)
def test_config(self):
@@ -166,7 +167,7 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
test_layer = attention.KernelAttention(
num_heads=num_heads,
key_dim=key_dim,
feature_transform='exp',
feature_transform="exp",
num_random_features=128,
is_short_seq=True)
new_layer = attention.KernelAttention.from_config(
@@ -174,5 +175,5 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(test_layer.get_config(), new_layer.get_config())
if __name__ == '__main__':
if __name__ == "__main__":
tf.test.main()