ModelZoo / ResNet50_tensorflow · Commits · 1f166323

Commit 1f166323 authored Aug 03, 2022 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 465060096
parent b1844216
Showing 2 changed files with 114 additions and 19 deletions (+114 −19)
official/nlp/modeling/layers/kernel_attention.py       +89 −16
official/nlp/modeling/layers/kernel_attention_test.py  +25 −3
official/nlp/modeling/layers/kernel_attention.py
@@ -98,11 +98,69 @@ def split_tensor_into_chunks(tensor, axis, chunk_length):
   return tf.reshape(tensor, new_shape)
 
 
+def rectangular_window_sum(tensor, window_length):
+  """Summarizes tensor elements over a sliding rectangular window.
+
+  Sums elements of the input tensor of shape [B, T', C', H, dim] across a
+  rectangular window sliding along the dimension T'.
+
+  Args:
+    tensor: Tensor of shape `[B, T', C', H, dim]`.
+    window_length: The length of the rectangular window.
+
+  Returns:
+    A tensor of shape [B, T', C', H, dim] containing sums over the window.
+  """
+  tensor_cumsum = tf.cumsum(tensor, axis=-4)
+  tensor_winsum = tensor_cumsum - tf.pad(
+      tensor_cumsum,
+      [[0, 0], [window_length, 0], [0, 0], [0, 0], [0, 0]])[:, :-window_length]
+  return tensor_winsum
+
+
+def weighted_window_sum(tensor, window_length, window_weights):
+  """Summarizes tensor elements over a sliding weighted window.
+
+  Computes a weighted sum of elements of the input tensor of shape
+  [B, T', C', H, dim] across a window sliding along the dimension T'.
+
+  Args:
+    tensor: Tensor of shape `[B, T', C', H, dim]`.
+    window_length: The length of the window.
+    window_weights: Tensor of shape [window_length] containing window weights.
+
+  Returns:
+    A tensor of shape [B, T', C', H, dim] containing sums over the window.
+  """
+  # Flatten the last three dimensions of the [B, T', C', H, dim] shape into a
+  # single channels dimension.
+  tensor_shape = tf.shape(tensor)
+  tensor_2d = tf.reshape(tensor, [tensor_shape[0], tensor_shape[1], 1, -1])
+  # Apply the same weights to all channels.
+  conv_filter = tf.tile(
+      tf.reshape(window_weights, [-1, 1, 1, 1]),
+      multiples=[1, 1, tf.shape(tensor_2d)[-1], 1])
+  tensor_winsum_2d = tf.nn.depthwise_conv2d(
+      tensor_2d,
+      conv_filter,
+      strides=[1, 1, 1, 1],
+      padding=[[0, 0], [window_length - 1, 0], [0, 0], [0, 0]])
+  # Unflatten the channels dimension into the original shape.
+  tensor_winsum = tf.reshape(tensor_winsum_2d, tensor_shape)
+  return tensor_winsum
+
+
 def causal_windowed_performer_attention(query_matrix,
                                         key_matrix,
                                         value_matrix,
                                         chunk_length,
                                         window_length,
+                                        window_decay=None,
                                         padding=None):
   """Applies windowed causal kernel attention with query, key, value tensors.
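Note: the new rectangular_window_sum helper relies on the standard cumulative-sum identity: the sum over the trailing window of length W along an axis equals cumsum(x) minus cumsum(x) shifted right by W. A minimal self-contained sketch of that trick in the 1-D case (names chosen for illustration only, not library code):

import tensorflow as tf

def rolling_window_sum_1d(x, window_length):
  # y[t] = sum of x[max(0, t - window_length + 1) : t + 1]
  cumsum = tf.cumsum(x)
  shifted = tf.pad(cumsum, [[window_length, 0]])[:-window_length]
  return cumsum - shifted

print(rolling_window_sum_1d(tf.ones([5]), 3).numpy())  # [1. 2. 3. 3. 3.]

weighted_window_sum generalizes the same idea by expressing the windowed sum as a causal depthwise convolution, so arbitrary per-position weights can be applied before summation.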
@@ -133,11 +191,14 @@ def causal_windowed_performer_attention(query_matrix,
     or right respectively and make it divisible by chunk_length.
 
   Args:
-    query_matrix: Kernel query `Tensor` of shape `[B, T, N, dim]`.
-    key_matrix: Kernel key `Tensor` of shape `[B, T, N, dim]`.
-    value_matrix: Value `Tensor` of shape `[B, T, N, out_dim]`.
+    query_matrix: Kernel query `Tensor` of shape `[B, T, H, dim]`.
+    key_matrix: Kernel key `Tensor` of shape `[B, T, H, dim]`.
+    value_matrix: Value `Tensor` of shape `[B, T, H, out_dim]`.
     chunk_length: Length of each chunk in tokens.
     window_length: Length of attention window in chunks.
+    window_decay: Float window decay factor or `None`. If set, exponentially
+      decay past attention window values by this factor before summation.
     padding: Pad the query, value and key input tensors across the axis from
       either left or right if padding is set to "left" or "right"; apply no
       padding if padding is set to None. In the
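Note: with the semantics documented above, window_decay=d weights a chunk that lies j chunks in the past by d**j, so the current chunk always has weight 1. A quick illustrative check (a sketch mirroring the weight construction added further down in this file):

import tensorflow as tf

window_decay, window_length = 0.9, 4
decaying_weights = tf.math.pow(
    tf.convert_to_tensor(window_decay, dtype=tf.float32),
    tf.range(window_length - 1, -1, delta=-1, dtype=tf.float32))
print(decaying_weights.numpy())  # 0.729, 0.81, 0.9, 1.0 (oldest to newest chunk)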
@@ -145,7 +206,7 @@ def causal_windowed_performer_attention(query_matrix,
       input tensors must be divisible by the chunk_length.
 
   Returns:
-    Window causal performer attention of shape `[B, T, N, out_dim]`.
+    Window causal performer attention of shape `[B, T, H, out_dim]`.
   """
   old_shape = tf.shape(value_matrix)
@@ -164,19 +225,26 @@ def causal_windowed_performer_attention(query_matrix,
       value_matrix, -3, chunk_length)
   # [-1, T//chunk_length, chunk_length, N, out_dim]
 
-  kp_v = tf.einsum("BNCHD,BNCHO->BNHDO", chunked_key_matrix,
+  kp_v = tf.einsum("BTCHD,BTCHO->BTHDO", chunked_key_matrix,
                    chunked_value_matrix)
-  kp_v_cumsum = tf.cumsum(kp_v, axis=-4)
-  kp_v_winsum = kp_v_cumsum - tf.pad(
-      kp_v_cumsum,
-      [[0, 0], [window_length, 0], [0, 0], [0, 0], [0, 0]])[:, :-window_length]
-  numerator = tf.einsum("BNCHD,BNHDO->BNCHO", chunked_query_matrix, kp_v_winsum)
-
-  k_sum = tf.reduce_sum(chunked_key_matrix, axis=-3)
-  k_cumsum = tf.cumsum(k_sum, axis=-3)
-  k_winsum = k_cumsum - tf.pad(
-      k_cumsum, [[0, 0], [window_length, 0], [0, 0], [0, 0]])[:, :-window_length]
-  denominator = tf.einsum("BNCHD,BNHD->BNCH", chunked_query_matrix, k_winsum)
+  k_sum = tf.math.reduce_sum(chunked_key_matrix, axis=-3, keepdims=True)
+  if window_decay is None:
+    kp_v_winsum = rectangular_window_sum(kp_v, window_length)
+    k_winsum = rectangular_window_sum(k_sum, window_length)
+  else:
+    # Compute exponentially decaying weights.
+    decaying_weights = tf.math.pow(
+        tf.convert_to_tensor(window_decay, dtype=value_matrix.dtype),
+        tf.range(window_length - 1, -1, delta=-1, dtype=value_matrix.dtype))
+    kp_v_winsum = weighted_window_sum(kp_v, window_length, decaying_weights)
+    k_winsum = weighted_window_sum(k_sum, window_length, decaying_weights)
+  numerator = tf.einsum("BTCHD,BTHDO->BTCHO", chunked_query_matrix, kp_v_winsum)
+  k_winsum = tf.squeeze(k_winsum, -3)
+  denominator = tf.einsum("BTCHD,BTHD->BTCH", chunked_query_matrix, k_winsum)
 
   denominator = tf.expand_dims(denominator, -1) + _NUMERIC_STABLER
   attention = numerator / denominator
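Note: the numerator/denominator structure above is easiest to read from the einsum shapes. With B batches, T chunks, C tokens per chunk, H heads, D kernel features and O value dimensions, "BTCHD,BTCHO->BTHDO" accumulates each chunk's sum of key-feature/value outer products, and "BTCHD,BTHDO->BTCHO" applies the windowed accumulator back to the queries. A self-contained shape check (a sketch with arbitrary small sizes, not the library code):

import tensorflow as tf

B, T, C, H, D, O = 2, 4, 3, 2, 5, 7
q = tf.random.uniform([B, T, C, H, D])
k = tf.random.uniform([B, T, C, H, D])
v = tf.random.uniform([B, T, C, H, O])

kp_v = tf.einsum("BTCHD,BTCHO->BTHDO", k, v)          # per-chunk sum of k' outer v
numerator = tf.einsum("BTCHD,BTHDO->BTCHO", q, kp_v)  # queries applied to window state
print(kp_v.shape, numerator.shape)  # (2, 4, 2, 5, 7) (2, 4, 3, 2, 7)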
@@ -351,7 +419,6 @@ def expplus(data_orig,
   diag_omega = tf.expand_dims(diag_omega, axis=0)
   diag_omega = tf.expand_dims(diag_omega, axis=0)
   diag_omega = a_coeff * diag_omega
-  #
   if numerical_renormalizer:
     if is_query:
@@ -454,6 +521,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
                use_causal_windowed=False,
                causal_chunk_length=1,
                causal_window_length=3,
+               causal_window_decay=None,
                causal_padding=None,
                **kwargs):
     r"""Constructor of KernelAttention.
@@ -485,6 +553,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
         causal_windowed_performer_attention function docstring for more details.
       causal_chunk_length: Length of each chunk in tokens.
       causal_window_length: Length of attention window in chunks.
+      causal_window_decay: Float window decay factor or `None`. If set,
+        exponentially decay past attention window values by this factor before
+        summation.
       causal_padding: Pad the query, value and key input tensors across the
         axis from either left or right if padding is set to "left" or "right";
         apply no padding if padding is set to None.
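Note: a hedged usage sketch of the new constructor argument. The argument names other than causal_window_decay are taken from the existing parameterization in kernel_attention_test.py, and the import alias is an assumption about the repository layout rather than something this commit defines.

from official.nlp.modeling.layers import kernel_attention as attention

layer = attention.KernelAttention(
    num_heads=12,
    key_dim=64,
    feature_transform="exp",
    num_random_features=127,
    use_causal_windowed=True,
    causal_chunk_length=8,     # tokens per chunk
    causal_window_length=3,    # chunks in the attention window
    causal_window_decay=0.97,  # exponentially down-weight older chunks
    causal_padding=None)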
@@ -524,6 +595,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     self.use_causal_windowed = use_causal_windowed
     self.causal_chunk_length = causal_chunk_length
     self.causal_window_length = causal_window_length
+    self.causal_window_decay = causal_window_decay
     self.causal_padding = causal_padding
     if self.use_causal_windowed and self._is_short_seq:
       raise ValueError(
@@ -608,6 +680,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
           query_prime,
           key_prime,
           value,
           chunk_length=self.causal_chunk_length,
           window_length=self.causal_window_length,
+          window_decay=self.causal_window_decay,
           padding=self.causal_padding)
     else:
       kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
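Note: a hedged smoke test of the standalone function called above, assuming it remains importable at module level (the test file below calls other module-level helpers the same way). q and k are passed through relu so they are non-negative, as kernel features would be, and T is divisible by chunk_length so no padding is needed.

import tensorflow as tf
from official.nlp.modeling.layers import kernel_attention as attention

B, T, H, dim, out_dim = 2, 16, 4, 8, 8
q = tf.nn.relu(tf.random.normal([B, T, H, dim]))
k = tf.nn.relu(tf.random.normal([B, T, H, dim]))
v = tf.random.normal([B, T, H, out_dim])

out = attention.causal_windowed_performer_attention(
    q, k, v, chunk_length=4, window_length=2, window_decay=0.9)
print(out.shape)  # (2, 16, 4, 8)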
official/nlp/modeling/layers/kernel_attention_test.py
@@ -61,11 +61,11 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
     self.assertEqual(output.shape, [batch_size, seq_length, key_dim])
 
   @parameterized.parameters(
-      itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
-                        [0], [None, "left", "right"]))
+      itertools.product(["relu", "exp"], [127], _TRAINING, [True, False],
+                        [0], [None, 0.97], [None, "left", "right"]))
   def test_causal_windowed_attention_projection(self, feature_transform,
                                                 num_random_features, training,
                                                 redraw,
-                                                begin_kernel, causal_padding):
+                                                begin_kernel,
+                                                causal_window_decay,
+                                                causal_padding):
     num_heads = 12
     key_dim = 64
     seq_length = 1024
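Note: the widened parameterization above multiplies out to a noticeably larger sweep. A quick count (a sketch; _TRAINING is defined elsewhere in the test module and is assumed here to be [True, False]):

import itertools

_TRAINING = [True, False]  # assumption, only for the sake of the count
combos = list(itertools.product(["relu", "exp"], [127], _TRAINING,
                                [True, False], [0], [None, 0.97],
                                [None, "left", "right"]))
print(len(combos))  # 48 parameter combinations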
@@ -81,6 +81,7 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
         use_causal_windowed=True,
         causal_chunk_length=8,
         causal_window_length=3,
+        causal_window_decay=causal_window_decay,
         causal_padding=causal_padding)
     query = tf.random.normal(shape=(batch_size, seq_length, key_dim))
@@ -175,5 +176,26 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
     # If the serialization was successful, the new config should match the old.
     self.assertAllEqual(test_layer.get_config(), new_layer.get_config())
 
+  def test_rectangular_window_sum(self):
+    x = tf.ones([2, 5, 2, 2, 2])
+    winsum = attention.rectangular_window_sum(x, 3)
+    self.assertEqual(winsum.shape, x.shape)
+    self.assertAllClose(
+        tf.tile(
+            tf.reshape([1., 2., 3., 3., 3.], [1, -1, 1, 1, 1]),
+            [2, 1, 2, 2, 2]),
+        winsum)
+
+  def test_weighted_window_sum(self):
+    x = tf.ones([2, 5, 2, 2, 2])
+    winsum = attention.weighted_window_sum(x, 3, [0.01, 0.1, 1.])
+    self.assertEqual(winsum.shape, x.shape)
+    self.assertAllClose(
+        tf.tile(
+            tf.reshape([1., 1.1, 1.11, 1.11, 1.11], [1, -1, 1, 1, 1]),
+            [2, 1, 2, 2, 2]),
+        winsum)
+
 
 if __name__ == "__main__":
   tf.test.main()
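Note: the expected values in the two new tests can be cross-checked by hand. For an all-ones input and window length 3, the sliding sums along T' are [1, 2, 3, 3, 3]; with weights [0.01, 0.1, 1.] (the most recent position weighted 1.) they are [1., 1.1, 1.11, 1.11, 1.11]. A brute-force NumPy sketch, independent of the library code:

import numpy as np

ones = np.ones(5)
weights = np.array([0.01, 0.1, 1.])
window = len(weights)

rect, weighted = [], []
for t in range(len(ones)):
  lo = max(0, t - window + 1)
  chunk = ones[lo:t + 1]
  rect.append(chunk.sum())
  weighted.append((chunk * weights[-len(chunk):]).sum())
print(rect)      # [1.0, 2.0, 3.0, 3.0, 3.0]
print(weighted)  # [1.0, 1.1, 1.11, 1.11, 1.11] (up to float rounding)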