ModelZoo / ResNet50_tensorflow

Commit a9a3bac9, authored Jul 29, 2022 by A. Unique TensorFlower
Parent: 8843bb24

Internal change

PiperOrigin-RevId: 464067452
Showing 2 changed files with 27 additions and 27 deletions:

  official/nlp/modeling/layers/kernel_attention.py        +16  -17
  official/nlp/modeling/layers/kernel_attention_test.py   +11  -10

official/nlp/modeling/layers/kernel_attention.py
@@ -58,6 +58,8 @@ def pad_to_chunk_length(tensor, axis, chunk_length, padding=None):
   Returns:
     Padded tensor with shape[axis] divisible by chunk_length.
   """
+  if padding is None:
+    return tensor
   shape = tf.shape(tensor)
   rank = tf.rank(tensor)
   if axis < 0:
@@ -68,14 +70,9 @@ def pad_to_chunk_length(tensor, axis, chunk_length, padding=None):
     axis_paddings = [[0, pad_length]]
   elif padding == "left":
     axis_paddings = [[pad_length, 0]]
-  elif padding is None:
-    if pad_length != 0:
-      raise ValueError("When padding is None, the axis dimension"
-                       "has to be divisible by the chunk_length.")
-    return tensor
   else:
-    raise ValueError("Illegal padding value; must be one of \"left\" "
-                     "\"right\" or None.")
+    raise ValueError(
+        "Illegal padding value; must be one of \"left\", \"right\" or None.")
   paddings = tf.concat(
       [tf.zeros([axis, 2], dtype=tf.int32),
        axis_paddings,
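
The two hunks above simplify pad_to_chunk_length: when padding is None the tensor is now returned unchanged up front, and otherwise the sequence axis is padded on the left or right up to the next multiple of chunk_length. A minimal standalone sketch of that behavior (an illustration, not the library helper; the name pad_to_multiple and the eager-mode use of static shapes are assumptions):

import tensorflow as tf

def pad_to_multiple(tensor, axis, chunk_length, padding=None):
  # No padding mode requested: return the input unchanged, mirroring the
  # early-return branch added in the hunk above.
  if padding is None:
    return tensor
  length = tensor.shape[axis]          # static length; fine for this sketch
  pad_length = -length % chunk_length  # amount needed to reach a multiple
  paddings = [[0, 0] for _ in range(tensor.shape.rank)]
  if padding == "right":
    paddings[axis] = [0, pad_length]
  elif padding == "left":
    paddings[axis] = [pad_length, 0]
  else:
    raise ValueError('padding must be one of "left", "right" or None.')
  return tf.pad(tensor, paddings)

x = tf.ones([2, 7, 4])  # batch of length-7 sequences
print(pad_to_multiple(x, axis=1, chunk_length=3, padding="right").shape)  # (2, 9, 4)
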
@@ -109,16 +106,18 @@ def causal_windowed_performer_attention(query_matrix,
                                         padding=None):
   """Applies windowed causal kernel attention with query, key, value tensors.
 
-  We partition the T-length input sequence into N chunks, each of chunk_length
-  tokens (thus: T = N * chunk_length). Within each chunk, we apply bidirectional
-  (non-causal) Performers’ implicit attention and we model relationships between
-  different chunks using Performers’ causal attention. We consider windowed
-  causal variant of performer, where the current chunk attends only to the
-  window of window_length of the most recent chunks.
+  We partition the T-length input sequence into N chunks, each of
+  chunk_length tokens (thus: T = N * chunk_length). Within each chunk,
+  we apply bidirectional (non-causal) Performers’ implicit attention
+  and we model relationships between different chunks using
+  Performers’ causal attention. We consider windowed causal variant of
+  performer, where the current chunk attends only to the window of
+  window_length of the most recent chunks.
 
-  Below is an example with T=9, chunk_length=3, window_length=2. In
-  this example 1 indicates attention is computed between the pair
-  while 0 indicates attention is not computed between the pairs:
+  Below is an example with T=9, chunk_length=3, window_length=1. 1 indicates
+  attention is computed between the pair while 0 indicates attention is not
+  computed between the pairs:
 
   111000000
   111000000
   111000000
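
To make the attention pattern described in the docstring concrete, the sketch below materializes the 9x9 chunked mask for T=9 and chunk_length=3, with each chunk attending to itself and one previous chunk (the function name and the "attended_chunks" parameterization are illustrative assumptions, not part of the library):

import numpy as np

def windowed_chunk_mask(seq_len, chunk_length, attended_chunks):
  # 1 where a query position may attend to a key position, 0 otherwise.
  # Chunk i attends to chunks max(0, i - attended_chunks + 1) .. i.
  num_chunks = seq_len // chunk_length
  mask = np.zeros((seq_len, seq_len), dtype=np.int32)
  for i in range(num_chunks):
    lo = max(0, i - attended_chunks + 1) * chunk_length
    hi = (i + 1) * chunk_length
    mask[i * chunk_length:hi, lo:hi] = 1
  return mask

print(windowed_chunk_mask(9, 3, attended_chunks=2))
# Rows 0-2: 111000000, rows 3-5: 111111000, rows 6-8: 000111111,
# matching the block pattern shown in the docstring above.
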
@@ -454,7 +453,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
                scale_by_length=False,
                use_causal_windowed=False,
                causal_chunk_length=1,
-               causal_window_length=1,
+               causal_window_length=3,
                causal_padding=None,
                **kwargs):
     r"""Constructor of KernelAttention.
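
The constructor hunk above changes the default causal_window_length from 1 to 3; together with causal_padding it controls the windowed causal attention path. A hedged usage sketch, mirroring how the test below builds the layer (the concrete values are illustrative, and the query/value call convention is assumed to follow tf.keras.layers.MultiHeadAttention):

import tensorflow as tf
from official.nlp.modeling.layers import kernel_attention as attention

layer = attention.KernelAttention(
    num_heads=12,
    key_dim=64,
    feature_transform="exp",
    num_random_features=128,
    use_causal_windowed=True,
    causal_chunk_length=8,
    causal_window_length=3,    # the new default shown above
    causal_padding="right")    # pad the sequence up to a multiple of the chunk length
query = tf.random.normal(shape=(2, 1024, 64))
output = layer(query, query)   # self-attention; assumed MultiHeadAttention-style call
print(output.shape)            # expected: (2, 1024, 64)
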

official/nlp/modeling/layers/kernel_attention_test.py
@@ -21,7 +21,7 @@ import tensorflow as tf
 from official.nlp.modeling.layers import kernel_attention as attention
 
-_FEATURE_TRANSFORM = ['relu', 'elu', 'exp', 'expplus']
+_FEATURE_TRANSFORM = ["relu", "elu", "exp", "expplus"]
 _REDRAW = [True, False]
 _TRAINING = [True, False]
 _IS_SHORT_SEQ = [True, False]
@@ -62,10 +62,10 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
 
   @parameterized.parameters(
       itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
-                        [0]))
+                        [0], [None, "left", "right"]))
   def test_causal_windowed_attention_projection(
       self, feature_transform, num_random_features, training, redraw,
-      begin_kernel):
+      begin_kernel, causal_padding):
     num_heads = 12
     key_dim = 64
     seq_length = 1024
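
The decorator change above adds causal_padding as a new axis of the parameter sweep; a quick illustration of what itertools.product now yields (the local variable names are only for illustration):

import itertools

feature_transform = ["relu", "elu", "exp", "expplus"]
training = [True, False]
cases = list(itertools.product(feature_transform, [127], training,
                               [True, False], [0], [None, "left", "right"]))
print(len(cases))   # 4 * 1 * 2 * 2 * 1 * 3 = 48 parameter combinations
print(cases[0])     # ('relu', 127, True, True, 0, None)
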
@@ -80,7 +80,8 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
         begin_kernel=begin_kernel,
         use_causal_windowed=True,
         causal_chunk_length=8,
-        causal_window_length=3)
+        causal_window_length=3,
+        causal_padding=causal_padding)
     query = tf.random.normal(shape=(batch_size, seq_length, key_dim))
     value = query
@@ -150,14 +151,14 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
     self.assertNotAllClose(output_scale_by_length, output_no_scale_by_length)
 
   def test_unsupported_feature_transform(self):
-    with self.assertRaisesRegex(ValueError, 'Unsupported feature_transform.*'):
-      _ = attention.KernelAttention(feature_transform='test')
+    with self.assertRaisesRegex(ValueError, "Unsupported feature_transform.*"):
+      _ = attention.KernelAttention(feature_transform="test")
 
   def test_redraw_true_no_projection(self):
     with self.assertRaisesRegex(
-        ValueError, 'There is nothing to redraw when num_random_features.*'):
+        ValueError, "There is nothing to redraw when num_random_features.*"):
       _ = attention.KernelAttention(
-          num_heads=2, key_dim=64, feature_transform='elu',
+          num_heads=2, key_dim=64, feature_transform="elu",
           num_random_features=0, redraw=True)
 
   def test_config(self):
@@ -166,7 +167,7 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
     test_layer = attention.KernelAttention(
         num_heads=num_heads,
         key_dim=key_dim,
-        feature_transform='exp',
+        feature_transform="exp",
         num_random_features=128,
         is_short_seq=True)
     new_layer = attention.KernelAttention.from_config(
@@ -174,5 +175,5 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
     # If the serialization was successful, the new config should match the old.
     self.assertAllEqual(test_layer.get_config(), new_layer.get_config())
 
-if __name__ == '__main__':
+if __name__ == "__main__":
   tf.test.main()