ModelZoo / ResNet50_tensorflow / Commits / 1f166323

Commit 1f166323, authored Aug 03, 2022 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 465060096
Parent: b1844216

2 changed files with 114 additions and 19 deletions:
  official/nlp/modeling/layers/kernel_attention.py        +89  -16
  official/nlp/modeling/layers/kernel_attention_test.py   +25  -3
official/nlp/modeling/layers/kernel_attention.py  (view file @ 1f166323)

@@ -98,11 +98,69 @@ def split_tensor_into_chunks(tensor, axis, chunk_length):
   return tf.reshape(tensor, new_shape)
 
 
+def rectangular_window_sum(tensor, window_length):
+  """Summarizes tensor elements over a sliding rectangular window.
+
+  Sums elements of the input tensor of shape [B, T', C', H, dim] across a
+  rectangular window sliding along the dimension T'.
+
+  Args:
+    tensor: Tensor of shape `[B, T', C', H, dim]`.
+    window_length: The length of the rectangular window.
+
+  Returns:
+    A tensor of shape [B, T', C', H, dim] containing sums over the window.
+  """
+  tensor_cumsum = tf.cumsum(tensor, axis=-4)
+  tensor_winsum = tensor_cumsum - tf.pad(
+      tensor_cumsum,
+      [[0, 0], [window_length, 0], [0, 0], [0, 0],
+       [0, 0]])[:, :-window_length]
+  return tensor_winsum
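
A minimal intuition sketch (not part of this commit): the helper computes the sliding sum as a difference of cumulative sums along T', i.e. winsum[t] = cumsum[t] - cumsum[t - window_length], so no per-window slices are materialized. With one scalar per position and a window of 3, the values match the unit test added below:

import tensorflow as tf

# Shape [B=1, T'=5, C'=1, H=1, dim=1]; axis -4 is the T' axis the helper sums over.
x = tf.ones([1, 5, 1, 1, 1])
winsum = rectangular_window_sum(x, window_length=3)
# Along T' the result is [1., 2., 3., 3., 3.]: each position sums itself and
# up to two preceding positions.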
+
+
+def weighted_window_sum(tensor, window_length, window_weights):
+  """Summarizes tensor elements over a sliding weighted window.
+
+  Computes a weighted sum of elements of the input tensor of shape
+  [B, T', C', H, dim] across a window sliding along the dimension T'.
+
+  Args:
+    tensor: Tensor of shape `[B, T', C', H, dim]`.
+    window_length: The length of the window.
+    window_weights: Tensor of shape [window_length] containing window weights.
+
+  Returns:
+    A tensor of shape [B, T', C', H, dim] containing sums over the window.
+  """
+  # Flatten the last three dimensions of the [B, T', C', H, dim] shape
+  # into a single channels dimension.
+  tensor_shape = tf.shape(tensor)
+  tensor_2d = tf.reshape(tensor, [tensor_shape[0], tensor_shape[1], 1, -1])
+  # Apply the same weights to all channels.
+  conv_filter = tf.tile(
+      tf.reshape(window_weights, [-1, 1, 1, 1]),
+      multiples=[1, 1, tf.shape(tensor_2d)[-1], 1])
+  tensor_winsum_2d = tf.nn.depthwise_conv2d(
+      tensor_2d,
+      conv_filter,
+      strides=[1, 1, 1, 1],
+      padding=[[0, 0], [window_length - 1, 0], [0, 0], [0, 0]])
+  # Unflatten the channels dimension into the original shape.
+  tensor_winsum = tf.reshape(tensor_winsum_2d, tensor_shape)
+  return tensor_winsum
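
Another sketch (not part of this commit): the weighted variant flattens [C', H, dim] into one channels axis and applies tf.nn.depthwise_conv2d along T', left-padding by window_length - 1 so the window stays causal; window_weights run oldest to newest, so the last weight scales the current position:

x = tf.ones([1, 5, 1, 1, 1])
winsum = weighted_window_sum(x, window_length=3,
                             window_weights=tf.constant([0.01, 0.1, 1.]))
# Along T' the result is [1., 1.1, 1.11, 1.11, 1.11], i.e.
# 1. * current + 0.1 * previous + 0.01 * two back.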
+
+
 def causal_windowed_performer_attention(query_matrix,
                                         key_matrix,
                                         value_matrix,
                                         chunk_length,
                                         window_length,
+                                        window_decay=None,
                                         padding=None):
   """Applies windowed causal kernel attention with query, key, value tensors.
@@ -133,11 +191,14 @@ def causal_windowed_performer_attention(query_matrix,
     or right respectively and make it divisible by chunk_length.
 
   Args:
-    query_matrix: Kernel query `Tensor` of shape `[B, T, N, dim]`.
-    key_matrix: Kernel key `Tensor` of shape `[B, T, N, dim]`.
-    value_matrix: Value `Tensor` of shape `[B, T, N, out_dim]`.
+    query_matrix: Kernel query `Tensor` of shape `[B, T, H, dim]`.
+    key_matrix: Kernel key `Tensor` of shape `[B, T, H, dim]`.
+    value_matrix: Value `Tensor` of shape `[B, T, H, out_dim]`.
     chunk_length: Length of each chunk in tokens.
     window_length: Length of attention window in chunks.
+    window_decay: Float window decay factor or `None`. If set, exponentially
+      decay past attention window values by this factor before summation.
     padding: Pad the query, value and key input tensors across the
       axis from either left or right if padding is set to "left" or
       "right"; apply no padding if padding is set to None. In the
@@ -145,7 +206,7 @@ def causal_windowed_performer_attention(query_matrix,
     input tensors must be divisible by the chunk_length.
 
   Returns:
-    Window causal performer attention of shape `[B, T, N, out_dim]`.
+    Window causal performer attention of shape `[B, T, H, out_dim]`.
   """
   old_shape = tf.shape(value_matrix)
@@ -164,19 +225,26 @@ def causal_windowed_performer_attention(query_matrix,
       value_matrix, -3,
       chunk_length)  # [-1, T//chunk_length, chunk_length, N, out_dim]
 
-  kp_v = tf.einsum("BNCHD,BNCHO->BNHDO", chunked_key_matrix,
+  kp_v = tf.einsum("BTCHD,BTCHO->BTHDO", chunked_key_matrix,
                    chunked_value_matrix)
-  kp_v_cumsum = tf.cumsum(kp_v, axis=-4)
-  kp_v_winsum = kp_v_cumsum - tf.pad(kp_v_cumsum,
-                                     [[0, 0], [window_length, 0], [0, 0],
-                                      [0, 0], [0, 0]])[:, :-window_length]
-  numerator = tf.einsum("BNCHD,BNHDO->BNCHO", chunked_query_matrix,
-                        kp_v_winsum)
-
-  k_sum = tf.reduce_sum(chunked_key_matrix, axis=-3)
-  k_cumsum = tf.cumsum(k_sum, axis=-3)
-  k_winsum = k_cumsum - tf.pad(k_cumsum, [[0, 0], [window_length, 0], [0, 0],
-                                          [0, 0]])[:, :-window_length]
-  denominator = tf.einsum("BNCHD,BNHD->BNCH", chunked_query_matrix, k_winsum)
+  k_sum = tf.math.reduce_sum(chunked_key_matrix, axis=-3, keepdims=True)
+
+  if window_decay is None:
+    kp_v_winsum = rectangular_window_sum(kp_v, window_length)
+    k_winsum = rectangular_window_sum(k_sum, window_length)
+  else:
+    # Compute exponentially decaying weights.
+    decaying_weights = tf.math.pow(
+        tf.convert_to_tensor(window_decay, dtype=value_matrix.dtype),
+        tf.range(window_length - 1, -1, delta=-1, dtype=value_matrix.dtype))
+    kp_v_winsum = weighted_window_sum(kp_v, window_length, decaying_weights)
+    k_winsum = weighted_window_sum(k_sum, window_length, decaying_weights)
+
+  numerator = tf.einsum("BTCHD,BTHDO->BTCHO", chunked_query_matrix,
+                        kp_v_winsum)
+  k_winsum = tf.squeeze(k_winsum, -3)
+  denominator = tf.einsum("BTCHD,BTHD->BTCH", chunked_query_matrix, k_winsum)
 
   denominator = tf.expand_dims(denominator, -1) + _NUMERIC_STABLER
   attention = numerator / denominator
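
A hedged usage sketch of the changed code path (not part of the diff), assuming the module is imported as `attention` the way the test file imports it. Passing window_decay=None keeps the previous rectangular-window behaviour; a float such as 0.97 switches to the exponentially weighted sums:

import tensorflow as tf
# Import path assumed from the file location shown in this diff.
from official.nlp.modeling.layers import kernel_attention as attention

b, t, h, dim = 2, 16, 4, 8
query_prime = tf.random.uniform([b, t, h, dim])  # kernel features, [B, T, H, dim]
key_prime = tf.random.uniform([b, t, h, dim])    # kernel features, [B, T, H, dim]
value = tf.random.normal([b, t, h, dim])         # [B, T, H, out_dim]

out = attention.causal_windowed_performer_attention(
    query_prime, key_prime, value,
    chunk_length=4,   # T = 16 is divisible by chunk_length, so padding=None is fine
    window_length=3,
    window_decay=0.97)
# out.shape == [2, 16, 4, 8], i.e. [B, T, H, out_dim]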
@@ -351,7 +419,6 @@ def expplus(data_orig,
   diag_omega = tf.expand_dims(diag_omega, axis=0)
   diag_omega = tf.expand_dims(diag_omega, axis=0)
   diag_omega = a_coeff * diag_omega
-  #
   if numerical_renormalizer:
     if is_query:
@@ -454,6 +521,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
                use_causal_windowed=False,
                causal_chunk_length=1,
                causal_window_length=3,
+               causal_window_decay=None,
                causal_padding=None,
                **kwargs):
     r"""Constructor of KernelAttention.
@@ -485,6 +553,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
        causal_windowed_performer_attention function docstring for more details.
      causal_chunk_length: Length of each chunk in tokens.
      causal_window_length: Length of attention window in chunks.
+     causal_window_decay: Float window decay factor or `None`. If set,
+       exponentially decay past attention window values by this factor
+       before summation.
      causal_padding: Pad the query, value and key input tensors across the
        axis from either left or right if padding is set to "left" or
        "right"; apply no padding if padding is set to None.
@@ -524,6 +595,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     self.use_causal_windowed = use_causal_windowed
     self.causal_chunk_length = causal_chunk_length
     self.causal_window_length = causal_window_length
+    self.causal_window_decay = causal_window_decay
     self.causal_padding = causal_padding
     if self.use_causal_windowed and self._is_short_seq:
       raise ValueError(
@@ -608,6 +680,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
           query_prime, key_prime, value,
           chunk_length=self.causal_chunk_length,
           window_length=self.causal_window_length,
+          window_decay=self.causal_window_decay,
           padding=self.causal_padding)
     else:
       kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
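
At the layer level the new argument is exposed as `causal_window_decay`. A construction sketch mirroring the updated test (the num_heads, key_dim, feature_transform, and num_random_features values are taken from the test's parameterization and are illustrative, not a recommended configuration):

layer = attention.KernelAttention(
    num_heads=12,
    key_dim=64,
    feature_transform="exp",
    num_random_features=127,
    use_causal_windowed=True,
    causal_chunk_length=8,
    causal_window_length=3,
    causal_window_decay=0.97,  # None disables the exponential decay
    causal_padding=None)

x = tf.random.normal(shape=(1, 1024, 64))
output = layer(query=x, value=x, training=False)  # shape [1, 1024, 64]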
official/nlp/modeling/layers/kernel_attention_test.py  (view file @ 1f166323)

@@ -61,11 +61,11 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
     self.assertEqual(output.shape, [batch_size, seq_length, key_dim])
 
   @parameterized.parameters(
-      itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
-                        [0], [None, "left", "right"]))
+      itertools.product(["relu", "exp"], [127], _TRAINING, [True, False],
+                        [0], [None, 0.97], [None, "left", "right"]))
   def test_causal_windowed_attention_projection(
       self, feature_transform, num_random_features, training, redraw,
-      begin_kernel, causal_padding):
+      begin_kernel, causal_window_decay, causal_padding):
     num_heads = 12
     key_dim = 64
     seq_length = 1024
@@ -81,6 +81,7 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
         use_causal_windowed=True,
         causal_chunk_length=8,
         causal_window_length=3,
+        causal_window_decay=causal_window_decay,
         causal_padding=causal_padding)
     query = tf.random.normal(
         shape=(batch_size, seq_length, key_dim))
@@ -175,5 +176,26 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
     # If the serialization was successful, the new config should match the old.
     self.assertAllEqual(test_layer.get_config(), new_layer.get_config())
 
+  def test_rectangular_window_sum(self):
+    x = tf.ones([2, 5, 2, 2, 2])
+    winsum = attention.rectangular_window_sum(x, 3)
+    self.assertEqual(winsum.shape, x.shape)
+    self.assertAllClose(
+        tf.tile(
+            tf.reshape([1., 2., 3., 3., 3.], [1, -1, 1, 1, 1]),
+            [2, 1, 2, 2, 2]),
+        winsum)
+
+  def test_weighted_window_sum(self):
+    x = tf.ones([2, 5, 2, 2, 2])
+    winsum = attention.weighted_window_sum(x, 3, [0.01, 0.1, 1.])
+    self.assertEqual(winsum.shape, x.shape)
+    self.assertAllClose(
+        tf.tile(
+            tf.reshape([1., 1.1, 1.11, 1.11, 1.11], [1, -1, 1, 1, 1]),
+            [2, 1, 2, 2, 2]),
+        winsum)
+
 
 if __name__ == "__main__":
   tf.test.main()
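
The expected constants in the two new tests follow directly from the window definitions; a quick pure-Python sanity check (a sketch, independent of TensorFlow):

ones = [1.0] * 5

# Rectangular window of length 3: sum the current and up to two previous items.
rect = [sum(ones[max(0, i - 2):i + 1]) for i in range(5)]
assert rect == [1.0, 2.0, 3.0, 3.0, 3.0]

# Weighted window [0.01, 0.1, 1.]: the last weight scales the current item.
weights = [0.01, 0.1, 1.0]
wtd = [
    sum(w * v for w, v in zip(weights[-(i + 1):], ones[max(0, i - 2):i + 1]))
    for i in range(5)
]
assert [round(v, 6) for v in wtd] == [1.0, 1.1, 1.11, 1.11, 1.11]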