"sgl-kernel/python/vscode:/vscode.git/clone" did not exist on "3db35c1af4a91479da526d8d3b1adadcfdeba054"
Commit d51cc280 authored by Jialu Liu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 441337748
parent 54659689
@@ -178,13 +178,13 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
                is_short_seq=False,
                begin_kernel=0,
                scale=None,
+               scale_by_length=False,
                **kwargs):
     r"""Constructor of KernelAttention.
 
     Args:
-      feature_transform: A non-linear transform of the keys and queries.
-        Possible transforms are "elu", "relu", "square", "exp", "expmod",
-        "identity".
+      feature_transform: A non-linear transform of the keys and queries. Possible
+        transforms are "elu", "relu", "square", "exp", "expmod", "identity".
       num_random_features: Number of random features to be used for projection.
         If num_random_features <= 0, no projection is used before the transform.
       seed: The seed to begin drawing random features. Once the seed is set, the
@@ -194,12 +194,16 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       redraw: Whether to redraw the projection every forward pass during training.
         The argument is only effective when num_random_features > 0.
-      is_short_seq: boolean predicate indicating whether input data consists of
-        very short sequences or not; in most cases this should be False
-        (default option).
+      is_short_seq: boolean predicate indicating whether input data consists of
+        very short sequences or not; in most cases this should be False (default
+        option).
       begin_kernel: Apply kernel attention after this sequence position and
         softmax attention before it.
       scale: The value to scale the dot product as described in `Attention Is
         All You Need`. If None, we use 1/sqrt(dk) as described in the paper.
+      scale_by_length: boolean predicate indicating whether to additionally scale
+        the dot product based on the key length. The scale is set to log_512(n)
+        to stabilize attention entropy against length. Refer to
+        https://kexue.fm/archives/8823 for details.
       **kwargs: The same arguments as for the `MultiHeadAttention` layer.
     """
     if feature_transform not in _TRANSFORM_MAP:
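
As a usage reference for the arguments documented above, a minimal construction sketch; the values and the `attention` module alias are illustrative (the alias follows the test file further down), not part of the commit:

    layer = attention.KernelAttention(
        num_heads=8,
        key_dim=64,
        feature_transform="exp",
        num_random_features=256,
        is_short_seq=False,
        begin_kernel=0,
        scale_by_length=True)
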
@@ -214,6 +218,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     self._redraw = redraw
     self._is_short_seq = is_short_seq
     self._begin_kernel = begin_kernel
+    self._scale_by_length = scale_by_length
     # We use the seed for two scenarios:
     # 1. inference
     # 2. no redraw
@@ -252,9 +257,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       is_short_seq: boolean predicate indicating whether input data consists of
         short or long sequences; usually a short sequence is defined as having
         length L <= 1024.
-      attention_mask: a boolean mask of shape `[B, S]` that prevents
-        attending to masked positions. Note that the mask is only applied to
-        the keys. The user may want to mask the output if the query contains pads.
+      attention_mask: a boolean mask of shape `[B, S]` that prevents attending
+        to masked positions. Note that the mask is only applied to the keys. The
+        user may want to mask the output if the query contains pads.
       training: Python boolean indicating whether the layer should behave in
         training mode (adding dropout) or in inference mode (doing nothing).
       numeric_stabler: A scalar value added to avoid division by zero.
@@ -270,17 +275,23 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     else:
       projection_matrix = self._projection_matrix
+    if self._scale_by_length:
+      scale = tf.math.log(tf.reduce_sum(attention_mask,
+                                        axis=-1)) * self._scale / math.log(512)
+      scale = tf.reshape(scale, [-1, 1, 1, 1])
+    else:
+      scale = self._scale
     if is_short_seq:
       # Note: Applying scalar multiply at the smaller end of einsum improves
       # XLA performance, but may introduce slight numeric differences in
       # the Transformer attention head.
-      query = query * self._scale
+      query = query * scale
     else:
       # Note: we suspect splitting the scale between key and query yields
       # smaller approximation variance when random projection is used.
       # For simplicity, we also split when there's no random projection.
-      key *= math.sqrt(self._scale)
-      query *= math.sqrt(self._scale)
+      key *= tf.math.sqrt(scale)
+      query *= tf.math.sqrt(scale)
     key = _TRANSFORM_MAP[feature_transform](key, projection_matrix)
     query = _TRANSFORM_MAP[feature_transform](query, projection_matrix)
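
Taken on its own, the length-based scaling added above multiplies the base scale 1/sqrt(d_k) by log_512(n), where n is the number of unmasked key positions. A minimal standalone sketch of that computation, assuming the default scale of 1/sqrt(key_dim); the helper name is illustrative, not part of the layer:

    import math
    import tensorflow as tf

    def length_scaled_attention_scale(attention_mask, key_dim):
      # Per-example scale log_512(n) / sqrt(key_dim), reshaped so it broadcasts
      # over heads and positions, mirroring the block above.
      base_scale = 1.0 / math.sqrt(key_dim)
      n = tf.reduce_sum(tf.cast(attention_mask, tf.float32), axis=-1)  # unmasked key length, shape [B]
      scale = tf.math.log(n) * base_scale / math.log(512.0)
      return tf.reshape(scale, [-1, 1, 1, 1])

For n == 512 the factor log_512(n) equals 1.0, so the result reduces to the plain 1/sqrt(d_k) scale, which is what the new test below relies on.
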
@@ -330,9 +341,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       value: Value `Tensor` of shape `[B, S, dim]`.
       key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
         `value` for both `key` and `value`, which is the most common case.
-      attention_mask: a boolean mask of shape `[B, S]` that prevents
-        attending to masked positions. Note that the mask is only applied to
-        the keys. The user may want to mask the output if the query contains pads.
+      attention_mask: a boolean mask of shape `[B, S]` that prevents attending
+        to masked positions. Note that the mask is only applied to the keys. The
+        user may want to mask the output if the query contains pads.
       training: Python boolean indicating whether the layer should behave in
         training mode (adding dropout) or in inference mode (doing nothing).
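
Since the mask above is applied only to the keys, a caller whose queries also contain padding may want to zero out the corresponding output rows, as the docstring suggests. A hedged sketch of that pattern for self-attention, where `layer`, `inputs`, and `mask` are illustrative names:

    # mask: [B, S], 1.0 for real tokens and 0.0 for padding; applied to the keys inside the layer.
    output = layer(query=inputs, value=inputs, attention_mask=mask)
    # Optionally zero out outputs at padded query positions.
    output *= tf.cast(mask, output.dtype)[:, :, tf.newaxis]
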
@@ -373,9 +384,10 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       attention_output = tf.concat(
           [attention_output_softmax, attention_output_kernel], axis=1)
     else:
-      attention_output = self._compute_attention(
-          query, key, value, self._feature_transform,
-          self._is_short_seq, attention_mask, training)
+      attention_output = self._compute_attention(query, key, value,
+                                                 self._feature_transform,
+                                                 self._is_short_seq,
+                                                 attention_mask, training)
     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
     attention_output = self._dropout_layer(attention_output)
@@ -30,8 +30,8 @@ _BEGIN_KERNEL = [0, 512]
 class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
 
-  @parameterized.parameters(itertools.product(
-      _FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
+  @parameterized.parameters(
+      itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
                         _IS_SHORT_SEQ, _BEGIN_KERNEL))
   def test_attention_projection(
       self, feature_transform, num_random_features, training, redraw, is_short,
@@ -90,6 +90,32 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
         training=training)
     self.assertEqual(output.shape, [batch_size, seq_length, key_dim])
 
+  @parameterized.parameters([128, 512])
+  def test_attention_scale_by_length(self, seq_length):
+    num_heads = 12
+    key_dim = 64
+    batch_size = 2
+    test_layer = attention.KernelAttention(
+        num_heads=num_heads,
+        key_dim=key_dim,
+        num_random_features=0,
+        scale_by_length=True)
+    query = tf.random.normal(
+        shape=(batch_size, seq_length, key_dim))
+    value = query
+    encoder_inputs_mask = tf.ones((batch_size, seq_length), dtype=tf.int32)
+    masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
+    output_scale_by_length = test_layer(
+        query=query, value=value, attention_mask=masks)
+    test_layer._scale_by_length = False
+    output_no_scale_by_length = test_layer(
+        query=query, value=value, attention_mask=masks)
+    if seq_length == 512:  # Equal because log_512(seq_length) = 1.0
+      self.assertAllClose(output_scale_by_length, output_no_scale_by_length)
+    else:
+      self.assertNotAllClose(output_scale_by_length, output_no_scale_by_length)
+
   def test_unsupported_feature_transform(self):
     with self.assertRaisesRegex(ValueError, 'Unsupported feature_transform.*'):
       _ = attention.KernelAttention(feature_transform='test')