Commit d51cc280 authored by Jialu Liu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 441337748
parent 54659689
@@ -178,13 +178,13 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
               is_short_seq=False,
               begin_kernel=0,
               scale=None,
               scale_by_length=False,
               **kwargs):
    r"""Constructor of KernelAttention.

    Args:
      feature_transform: A non-linear transform of the keys and queries.
        Possible transforms are "elu", "relu", "square", "exp", "expmod",
        "identity".
      num_random_features: Number of random features to be used for projection.
        If num_random_features <= 0, no projection is applied before the
        transform.
      seed: The seed to begin drawing random features. Once the seed is set, the
@@ -194,12 +194,16 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
      redraw: Whether to redraw projection every forward pass during training.
        The argument is only effective when num_random_features > 0.
      is_short_seq: boolean predicate indicating whether input data consists of
        very short sequences or not; in most cases this should be False (default
        option).
      begin_kernel: Apply kernel_attention after this sequence id and apply
        softmax attention before this.
      scale: The value to scale the dot product as described in `Attention Is
        All You Need`. If None, we use 1/sqrt(dk) as described in the paper.
      scale_by_length: boolean predicate indicating whether to additionally
        scale the dot product based on the key length. The scale is multiplied
        by log_512(n), where n is the key length, to stabilize attention
        entropy against length. Refer to https://kexue.fm/archives/8823 for
        details.
      **kwargs: The same arguments as the `MultiHeadAttention` layer.
    """
    if feature_transform not in _TRANSFORM_MAP:
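As a reference for the new `scale_by_length` option documented above: the length-dependent factor is log_512(n) applied on top of the usual 1/sqrt(d_k) scale, computed inline in `_compute_attention` further down in this diff. A minimal sketch of that computation as a standalone helper (the `length_adjusted_scale` name and the free-standing form are illustrative assumptions, not part of this change):

```python
import math

import tensorflow as tf


def length_adjusted_scale(attention_mask, base_scale):
  """Illustrative helper: multiply `base_scale` by log_512(key length).

  Args:
    attention_mask: float Tensor of shape [B, S]; 1.0 marks valid key positions.
    base_scale: the usual 1/sqrt(d_k) scalar.

  Returns:
    A [B, 1, 1, 1] Tensor that broadcasts over heads and query positions.
  """
  key_lengths = tf.reduce_sum(attention_mask, axis=-1)  # [B]
  scale = tf.math.log(key_lengths) * base_scale / math.log(512)
  return tf.reshape(scale, [-1, 1, 1, 1])
```

For a 512-token key sequence the factor is log(512)/log(512) = 1, so the result matches `scale_by_length=False`; shorter sequences are scaled down and longer ones up, which is what the test added below checks.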
@@ -214,6 +218,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
    self._redraw = redraw
    self._is_short_seq = is_short_seq
    self._begin_kernel = begin_kernel
    self._scale_by_length = scale_by_length
    # We use the seed for two scenarios:
    # 1. inference
    # 2. no redraw
@@ -252,9 +257,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
      is_short_seq: boolean predicate indicating whether input data consists of
        short or long sequences; usually a short sequence is defined as having
        length L <= 1024.
      attention_mask: a boolean mask of shape `[B, S]` that prevents attending
        to masked positions. Note that the mask is only applied to the keys.
        The user may want to mask the output if the query contains pads.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).
      numeric_stabler: A scalar value added to avoid divide by 0.
@@ -270,17 +275,23 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
    else:
      projection_matrix = self._projection_matrix
    if self._scale_by_length:
      scale = tf.math.log(tf.reduce_sum(attention_mask,
                                        axis=-1)) * self._scale / math.log(512)
      scale = tf.reshape(scale, [-1, 1, 1, 1])
    else:
      scale = self._scale
    if is_short_seq:
      # Note: Applying scalar multiply at the smaller end of einsum improves
      # XLA performance, but may introduce slight numeric differences in
      # the Transformer attention head.
      query = query * scale
    else:
      # Note: we suspect splitting the scale to key, query yields smaller
      # approximation variance when random projection is used.
      # For simplicity, we also split when there's no random projection.
      key *= tf.math.sqrt(scale)
      query *= tf.math.sqrt(scale)
    key = _TRANSFORM_MAP[feature_transform](key, projection_matrix)
    query = _TRANSFORM_MAP[feature_transform](query, projection_matrix)
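Splitting the scale as a square root on both `key` and `query` leaves the plain dot-product logits unchanged, since (sqrt(s)·q)·(sqrt(s)·k) = s·(q·k); with a non-linear `feature_transform` the split is the variance-reduction heuristic the code comment describes rather than an exact identity. A quick numeric check of the dot-product identity, with made-up shapes purely for illustration:

```python
import tensorflow as tf

s = 0.125  # example scale, e.g. 1 / sqrt(64)
q = tf.random.normal([2, 4, 8])  # [batch, length, dim]; shapes are illustrative
k = tf.random.normal([2, 4, 8])

direct = s * tf.einsum('bqd,bkd->bqk', q, k)
split = tf.einsum('bqd,bkd->bqk', q * tf.sqrt(s), k * tf.sqrt(s))
tf.debugging.assert_near(direct, split, atol=1e-5)  # equal up to float rounding
```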
@@ -330,9 +341,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
      value: Value `Tensor` of shape `[B, S, dim]`.
      key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
        `value` for both `key` and `value`, which is the most common case.
      attention_mask: a boolean mask of shape `[B, S]` that prevents attending
        to masked positions. Note that the mask is only applied to the keys.
        The user may want to mask the output if the query contains pads.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).
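Because `attention_mask` only gates the keys, padded query positions still produce outputs; as the docstring suggests, one option is to zero those rows out after the call. A small usage sketch under assumptions not stated in this diff: the import path and every tensor name below are illustrative, and the layer is built the same way as in the test further down:

```python
import tensorflow as tf
# Assumed import path for the layer being changed here.
from official.nlp.modeling.layers import kernel_attention as attention

batch_size, q_len, kv_len, dim = 2, 6, 8, 64
layer = attention.KernelAttention(
    num_heads=4, key_dim=dim, num_random_features=0)

query = tf.random.normal([batch_size, q_len, dim])
value = tf.random.normal([batch_size, kv_len, dim])
key_mask = tf.ones([batch_size, kv_len])  # gates which keys can be attended to
query_mask = tf.concat(  # pretend the last two query positions are padding
    [tf.ones([batch_size, q_len - 2]), tf.zeros([batch_size, 2])], axis=-1)

outputs = layer(query=query, value=value, attention_mask=key_mask)
outputs *= query_mask[:, :, tf.newaxis]  # zero the rows of padded queries
```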
@@ -373,9 +384,10 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
      attention_output = tf.concat(
          [attention_output_softmax, attention_output_kernel], axis=1)
    else:
      attention_output = self._compute_attention(query, key, value,
                                                 self._feature_transform,
                                                 self._is_short_seq,
                                                 attention_mask, training)
    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_output = self._dropout_layer(attention_output)
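For the `begin_kernel` branch above, the softmax and kernel attention outputs cover disjoint slices of the query sequence and are stitched back together by the `tf.concat(..., axis=1)` shown in the diff. A toy shape check with invented dimensions, just to make the axis explicit:

```python
import tensorflow as tf

# E.g. begin_kernel=512 on a 1024-long sequence: softmax attention handles the
# first 512 positions, kernel attention the remaining 512. Shapes are made up.
softmax_part = tf.zeros([2, 512, 768])
kernel_part = tf.zeros([2, 512, 768])
combined = tf.concat([softmax_part, kernel_part], axis=1)
assert combined.shape == [2, 1024, 768]  # axis=1 is the query/time axis
```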
......
@@ -30,9 +30,9 @@ _BEGIN_KERNEL = [0, 512]

class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(
      itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
                        _IS_SHORT_SEQ, _BEGIN_KERNEL))
  def test_attention_projection(
      self, feature_transform, num_random_features, training, redraw, is_short,
      begin_kernel):
@@ -90,6 +90,32 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
        training=training)
    self.assertEqual(output.shape, [batch_size, seq_length, key_dim])
  @parameterized.parameters([128, 512])
  def test_attention_scale_by_length(self, seq_length):
    num_heads = 12
    key_dim = 64
    batch_size = 2
    test_layer = attention.KernelAttention(
        num_heads=num_heads,
        key_dim=key_dim,
        num_random_features=0,
        scale_by_length=True)
    query = tf.random.normal(
        shape=(batch_size, seq_length, key_dim))
    value = query
    encoder_inputs_mask = tf.ones((batch_size, seq_length), dtype=tf.int32)
    masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
    output_scale_by_length = test_layer(
        query=query, value=value, attention_mask=masks)
    test_layer._scale_by_length = False
    output_no_scale_by_length = test_layer(
        query=query, value=value, attention_mask=masks)
    if seq_length == 512:  # Equals because log(seq_length, base=512) = 1.0
      self.assertAllClose(output_scale_by_length, output_no_scale_by_length)
    else:
      self.assertNotAllClose(output_scale_by_length, output_no_scale_by_length)
  def test_unsupported_feature_transform(self):
    with self.assertRaisesRegex(ValueError, 'Unsupported feature_transform.*'):
      _ = attention.KernelAttention(feature_transform='test')
......