".github/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "af336d66944689e1eacd20a2a040e5fc56c31045"
Commit 0028cbed authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 465060096
parent c2ddc0bd
@@ -98,11 +98,69 @@ def split_tensor_into_chunks(tensor, axis, chunk_length):
  return tf.reshape(tensor, new_shape)


def rectangular_window_sum(tensor, window_length):
  """Summarizes tensor elements over a sliding rectangular window.

  Sums elements of the input tensor of shape [B, T', C', H, dim]
  across a rectangular window sliding along the dimension T'.

  Args:
    tensor: Tensor of shape `[B, T', C', H, dim]`.
    window_length: The length of the rectangular window.

  Returns:
    A tensor of shape `[B, T', C', H, dim]` containing sums over the window.
  """
  tensor_cumsum = tf.cumsum(tensor, axis=-4)
  tensor_winsum = tensor_cumsum - tf.pad(
      tensor_cumsum,
      [[0, 0], [window_length, 0], [0, 0], [0, 0], [0, 0]])[:, :-window_length]
  return tensor_winsum
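

# For example, with an all-ones input and window_length=3, the sliding sums
# along T' are [1., 2., 3., 3., 3.]: each position sums itself and up to
# window_length - 1 preceding positions, so the window is truncated at the
# start of the sequence.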


def weighted_window_sum(tensor, window_length, window_weights):
  """Summarizes tensor elements over a sliding weighted window.

  Computes a weighted sum of elements of the input tensor of shape
  [B, T', C', H, dim] across a window sliding along the dimension T'.

  Args:
    tensor: Tensor of shape `[B, T', C', H, dim]`.
    window_length: The length of the window.
    window_weights: Tensor of shape `[window_length]` containing window
      weights.

  Returns:
    A tensor of shape `[B, T', C', H, dim]` containing sums over the window.
  """
  # Flatten the last three dimensions of the [B, T', C', H, dim] shape
  # into a single channels dimension.
  tensor_shape = tf.shape(tensor)
  tensor_2d = tf.reshape(tensor, [tensor_shape[0], tensor_shape[1], 1, -1])
  # Apply the same weights to all channels.
  conv_filter = tf.tile(
      tf.reshape(window_weights, [-1, 1, 1, 1]),
      multiples=[1, 1, tf.shape(tensor_2d)[-1], 1])
  tensor_winsum_2d = tf.nn.depthwise_conv2d(
      tensor_2d,
      conv_filter,
      strides=[1, 1, 1, 1],
      padding=[[0, 0], [window_length - 1, 0], [0, 0], [0, 0]])
  # Unflatten the channels dimension into the original shape.
  tensor_winsum = tf.reshape(tensor_winsum_2d, tensor_shape)
  return tensor_winsum
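

# For example, with an all-ones input, window_length=3 and window_weights
# [0.01, 0.1, 1.], the weighted sums along T' are [1., 1.1, 1.11, 1.11, 1.11]:
# the newest position gets weight 1. and the two preceding positions get the
# smaller weights, with truncated windows at the start of the sequence.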


def causal_windowed_performer_attention(query_matrix,
                                        key_matrix,
                                        value_matrix,
                                        chunk_length,
                                        window_length,
                                        window_decay=None,
                                        padding=None):
  """Applies windowed causal kernel attention with query, key, value tensors.
@@ -133,11 +191,14 @@ def causal_windowed_performer_attention(query_matrix,
  or right respectively and make it divisible by chunk_length.

  Args:
    query_matrix: Kernel query `Tensor` of shape `[B, T, H, dim]`.
    key_matrix: Kernel key `Tensor` of shape `[B, T, H, dim]`.
    value_matrix: Value `Tensor` of shape `[B, T, H, out_dim]`.
    chunk_length: Length of each chunk in tokens.
    window_length: Length of attention window in chunks.
    window_decay: Float window decay factor or `None`. If set,
      exponentially decay past attention window values by this factor
      before summation.
    padding: Pad the query, value and key input tensors across the
      axis from either left or right if padding is set to "left" or
      "right"; apply no padding if padding is set to None. In the
@@ -145,7 +206,7 @@ def causal_windowed_performer_attention(query_matrix,
      input tensors must be divisible by the chunk_length.

  Returns:
    Window causal performer attention of shape `[B, T, H, out_dim]`.
  """
  old_shape = tf.shape(value_matrix)
@@ -164,19 +225,26 @@ def causal_windowed_performer_attention(query_matrix,
      value_matrix, -3,
      chunk_length)  # [-1, T//chunk_length, chunk_length, N, out_dim]
kp_v = tf.einsum("BNCHD,BNCHO->BNHDO", chunked_key_matrix, kp_v = tf.einsum("BTCHD,BTCHO->BTHDO", chunked_key_matrix,
chunked_value_matrix) chunked_value_matrix)
kp_v_cumsum = tf.cumsum(kp_v, axis=-4)
kp_v_winsum = kp_v_cumsum - tf.pad(
kp_v_cumsum,
[[0, 0], [window_length, 0], [0, 0], [0, 0], [0, 0]])[:, :-window_length]
numerator = tf.einsum("BNCHD,BNHDO->BNCHO", chunked_query_matrix, kp_v_winsum)
k_sum = tf.reduce_sum(chunked_key_matrix, axis=-3) k_sum = tf.math.reduce_sum(chunked_key_matrix, axis=-3, keepdims=True)
k_cumsum = tf.cumsum(k_sum, axis=-3)
k_winsum = k_cumsum - tf.pad(k_cumsum, [[0, 0], [window_length, 0], [0, 0], if window_decay is None:
[0, 0]])[:, :-window_length] kp_v_winsum = rectangular_window_sum(kp_v, window_length)
denominator = tf.einsum("BNCHD,BNHD->BNCH", chunked_query_matrix, k_winsum) k_winsum = rectangular_window_sum(k_sum, window_length)
else:
# Compute exponentially decaying weights.
decaying_weights = tf.math.pow(
tf.convert_to_tensor(window_decay, dtype=value_matrix.dtype),
tf.range(window_length - 1, -1, delta=-1, dtype=value_matrix.dtype))
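    # For example, window_decay=0.97 and window_length=3 give weights
    # [0.9409, 0.97, 1.], so the newest chunk gets full weight and older
    # chunks in the window are down-weighted geometrically.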
    kp_v_winsum = weighted_window_sum(kp_v, window_length, decaying_weights)
    k_winsum = weighted_window_sum(k_sum, window_length, decaying_weights)
  numerator = tf.einsum("BTCHD,BTHDO->BTCHO", chunked_query_matrix, kp_v_winsum)
  k_winsum = tf.squeeze(k_winsum, -3)
  denominator = tf.einsum("BTCHD,BTHD->BTCH", chunked_query_matrix, k_winsum)
  denominator = tf.expand_dims(denominator, -1) + _NUMERIC_STABLER
  attention = numerator / denominator
@@ -351,7 +419,6 @@ def expplus(data_orig,
  diag_omega = tf.expand_dims(diag_omega, axis=0)
  diag_omega = tf.expand_dims(diag_omega, axis=0)
  diag_omega = a_coeff * diag_omega
  if numerical_renormalizer:
    if is_query:
@@ -454,6 +521,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
               use_causal_windowed=False,
               causal_chunk_length=1,
               causal_window_length=3,
               causal_window_decay=None,
               causal_padding=None,
               **kwargs):
    r"""Constructor of KernelAttention.
@@ -485,6 +553,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
        causal_windowed_performer_attention function docstring for more details.
      causal_chunk_length: Length of each chunk in tokens.
      causal_window_length: Length of attention window in chunks.
      causal_window_decay: Float window decay factor or `None`. If set,
        exponentially decay past attention window values by this
        factor before summation.
      causal_padding: Pad the query, value and key input tensors
        across the axis from either left or right if padding is set to
        "left" or "right"; apply no padding if padding is set to None.
@@ -524,6 +595,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
    self.use_causal_windowed = use_causal_windowed
    self.causal_chunk_length = causal_chunk_length
    self.causal_window_length = causal_window_length
    self.causal_window_decay = causal_window_decay
    self.causal_padding = causal_padding
    if self.use_causal_windowed and self._is_short_seq:
      raise ValueError(
@@ -608,6 +680,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
          query_prime, key_prime, value,
          chunk_length=self.causal_chunk_length,
          window_length=self.causal_window_length,
          window_decay=self.causal_window_decay,
          padding=self.causal_padding)
    else:
      kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
...
@@ -61,11 +61,11 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
    self.assertEqual(output.shape, [batch_size, seq_length, key_dim])

  @parameterized.parameters(
      itertools.product(["relu", "exp"], [127], _TRAINING, [True, False],
                        [0], [None, 0.97], [None, "left", "right"]))
  def test_causal_windowed_attention_projection(
      self, feature_transform, num_random_features, training, redraw,
      begin_kernel, causal_window_decay, causal_padding):
    num_heads = 12
    key_dim = 64
    seq_length = 1024
@@ -81,6 +81,7 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
        use_causal_windowed=True,
        causal_chunk_length=8,
        causal_window_length=3,
        causal_window_decay=causal_window_decay,
        causal_padding=causal_padding)
    query = tf.random.normal(
        shape=(batch_size, seq_length, key_dim))
@@ -175,5 +176,26 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(test_layer.get_config(), new_layer.get_config())

  def test_rectangular_window_sum(self):
    x = tf.ones([2, 5, 2, 2, 2])
    winsum = attention.rectangular_window_sum(x, 3)
    self.assertEqual(winsum.shape, x.shape)
    self.assertAllClose(
        tf.tile(
            tf.reshape([1., 2., 3., 3., 3.], [1, -1, 1, 1, 1]),
            [2, 1, 2, 2, 2]),
        winsum)

  def test_weighted_window_sum(self):
    x = tf.ones([2, 5, 2, 2, 2])
    winsum = attention.weighted_window_sum(x, 3, [0.01, 0.1, 1.])
    self.assertEqual(winsum.shape, x.shape)
    self.assertAllClose(
        tf.tile(
            tf.reshape([1., 1.1, 1.11, 1.11, 1.11], [1, -1, 1, 1, 1]),
            [2, 1, 2, 2, 2]),
        winsum)
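
  # Illustrative sketch (an assumed extra check, not from the original tests):
  # a rectangular window sum is the special case of a weighted window sum with
  # all-ones weights, so the two helpers should agree for unit weights.
  def test_window_sums_agree_for_unit_weights(self):
    x = tf.random.normal([2, 5, 2, 2, 2])
    self.assertAllClose(
        attention.rectangular_window_sum(x, 3),
        attention.weighted_window_sum(x, 3, tf.ones([3])))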

if __name__ == "__main__":
  tf.test.main()