"test/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "1def3fa99867e45728c908e2ca557eb02681aad1"
Commit 0b579232 authored by Hongkun Yu, committed by A. Unique TensorFlower

Implement CachedAttention layer. This is useful for decoders.

PiperOrigin-RevId: 282439719
parent bcce419a
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Layers package definition."""
from official.nlp.modeling.layers.attention import Attention
from official.nlp.modeling.layers.attention import * # pylint: disable=wildcard-import
from official.nlp.modeling.layers.dense_einsum import DenseEinsum
from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
@@ -119,7 +119,7 @@ class Attention(tf.keras.layers.Layer):
self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
def compute_output_shape(self, input_shape):
# TODO(momernick): validate tensor dimensioos
# TODO(momernick): validate tensor dimensions.
from_tensor_shape = tf.TensorShape(input_shape[0])
batch = from_tensor_shape[0]
from_tensor_length = from_tensor_shape[1]
@@ -188,3 +188,85 @@ class Attention(tf.keras.layers.Layer):
# `context_layer` = [B, F, N, H]
return tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor)
@tf.keras.utils.register_keras_serializable(package="Text")
class CachedAttention(Attention):
"""Attention layer with cache used for auto-agressive decoding.
Attributes:
num_heads: Number of attention heads.
head_size: Size of each attention head.
**kwargs: Other keyword arguments inherited from the `Attention` class.
"""
def __init__(self, num_heads, head_size, **kwargs):
super(CachedAttention, self).__init__(num_heads, head_size, **kwargs)
def _update_cache(self, key_tensor, value_tensor, cache, decode_loop_step):
"""Updates cache states and gets full-length key/value tensors."""
# Combines cached keys and values with new keys and values.
if decode_loop_step is not None:
# TPU special case.
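# The cache is pre-allocated to the full decode length, so the new key/value
# is scattered into position `decode_loop_step` via a one-hot mask instead of
# concatenation (keeping shapes static for the TPU/XLA decoding loop).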
key_seq_dim = cache["key"].shape.as_list()[1]
indices = tf.reshape(
tf.one_hot(decode_loop_step, key_seq_dim, dtype=key_tensor.dtype),
[1, key_seq_dim, 1, 1])
key_tensor = cache["key"] + key_tensor * indices
value_seq_dim = cache["value"].shape.as_list()[1]
indices = tf.reshape(
tf.one_hot(decode_loop_step, value_seq_dim, dtype=value_tensor.dtype),
[1, value_seq_dim, 1, 1])
value_tensor = cache["value"] + value_tensor * indices
else:
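# CPU/GPU case: the cache grows one step at a time, so the new key/value is
# simply concatenated onto the cached tensors along the sequence axis.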
key_tensor = tf.concat(
[tf.cast(cache["key"], key_tensor.dtype), key_tensor], axis=1)
value_tensor = tf.concat(
[tf.cast(cache["value"], value_tensor.dtype), value_tensor], axis=1)
# Update cache
cache["key"] = key_tensor
cache["value"] = value_tensor
return key_tensor, value_tensor
def call(self, inputs, decode_loop_step=None):
from_tensor = inputs[0]
to_tensor = inputs[1]
attention_mask = inputs[2] if len(inputs) >= 3 else None
cache = inputs[3] if len(inputs) >= 4 else None
# Scalar dimensions referenced here:
# B = batch size (number of sequences)
# F = `from_tensor` sequence length
# T = `to_tensor` sequence length
# N = `num_attention_heads`
# H = `size_per_head`
# `query_tensor` = [B, F, N, H]
query_tensor = self._query_dense(from_tensor)
# `key_tensor` = [B, T, N, H]
key_tensor = self._key_dense(to_tensor)
# `value_tensor` = [B, T, N, H]
value_tensor = self._value_dense(to_tensor)
if cache:
key_tensor, value_tensor = self._update_cache(key_tensor, value_tensor,
cache, decode_loop_step)
# Take the dot product between "query" and "key" to get the raw
# attention scores.
attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._head_size)))
# Normalize the attention scores to probabilities.
# `attention_probs` = [B, N, F, T]
attention_probs = self._masked_softmax([attention_scores, attention_mask])
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self._dropout(attention_probs)
# `context_layer` = [B, F, N, H]
return tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor), cache
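For reference, a minimal usage sketch (illustrative only: the toy shapes and the `max_decode_length` name are assumptions, mirroring the padded-decode test below):

import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import attention

batch_size, max_decode_length, num_heads, head_size = 3, 4, 2, 2
layer = attention.CachedAttention(num_heads=num_heads, head_size=head_size)

# Pre-allocate the cache to the full decode length (padded, TPU-style decode).
cache = {
    "key": tf.zeros([batch_size, max_decode_length, num_heads, head_size]),
    "value": tf.zeros([batch_size, max_decode_length, num_heads, head_size]),
}

query = tf.zeros((batch_size, max_decode_length, 8), dtype=tf.float32)
mask = np.ones((batch_size, max_decode_length, max_decode_length), dtype=np.int32)

# At decode step 2 the new key/value projections are written into cache
# position 2; the layer returns the attention output and the updated cache.
output, cache = layer([query, query, mask, cache], decode_loop_step=2)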
@@ -88,5 +88,70 @@ class AttentionLayerTest(keras_parameterized.TestCase):
self.assertEqual(output.shape.as_list(), [None, 40, 12, 64])
if __name__ == '__main__':
def _create_cache(batch_size, init_decode_length, num_heads, head_size):
return {
"key":
tf.zeros([batch_size, init_decode_length, num_heads, head_size],
dtype=tf.float32),
"value":
tf.zeros([batch_size, init_decode_length, num_heads, head_size],
dtype=tf.float32)
}
@keras_parameterized.run_all_keras_modes
class CachedAttentionTest(keras_parameterized.TestCase):
def test_masked_attention(self):
"""Test with a mask tensor."""
num_heads, head_size = 2, 2
# Create a 3-dimensional input (the first dimension is implicit).
from_seq_length = 4
batch_size = 3
# GPU/CPU case.
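# With init_decode_length = 0 the cache starts empty and grows by
# concatenation inside _update_cache.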
init_decode_length = 0
# Directly tests the keras layer.
cache = _create_cache(batch_size, init_decode_length, num_heads, head_size)
layer = attention.CachedAttention(num_heads=num_heads, head_size=head_size)
# Generate data for the input (non-mask) tensors.
from_data = tf.zeros((batch_size, from_seq_length, 8), dtype=np.float32)
# Invoke the layer with a random set of mask data. This should mask at least
# one element.
mask_data = np.random.randint(
2, size=(batch_size, from_seq_length, from_seq_length))
masked_output_data, cache = layer([from_data, from_data, mask_data, cache])
self.assertEqual(masked_output_data.shape, (3, 4, 2, 2))
self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
# Tests inputs without cache.
masked_output_data, cache = layer([from_data, from_data, mask_data])
self.assertEqual(masked_output_data.shape, (3, 4, 2, 2))
self.assertIsNone(cache)
def test_padded_decode(self):
"""Test with a mask tensor."""
num_heads, head_size = 2, 2
from_seq_length = 4
# TPU decoding should pre-allocate the entire sequence.
batch_size = 3
init_decode_length = from_seq_length
# Directly tests the keras layer.
cache = _create_cache(batch_size, init_decode_length, num_heads, head_size)
layer = attention.CachedAttention(num_heads=num_heads, head_size=head_size)
# Generate data for the input (non-mask) tensors.
from_data = tf.zeros((batch_size, from_seq_length, 8), dtype=np.float32)
decode_loop_step = 2
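# With a pre-allocated cache, step 2 writes the new key/value into cache
# position 2 via the one-hot update in _update_cache.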
mask_data = np.random.randint(
2, size=(batch_size, from_seq_length, from_seq_length), dtype=np.int32)
# Test the layer invocation directly, since the standard Keras call path
# cannot consume these inputs correctly.
masked_output_data, cache = layer([from_data, from_data, mask_data, cache],
decode_loop_step=decode_loop_step)
self.assertEqual(masked_output_data.shape, (3, 4, 2, 2))
self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
if __name__ == "__main__":
tf.test.main()