ModelZoo / ResNet50_tensorflow · Commits

Commit 4ad903b4, authored Jul 27, 2022 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 463764367

parent 93245b4f
Changes: 2 changed files, with 74 additions and 57 deletions (+74 -57)

official/nlp/modeling/layers/kernel_attention.py (+70 -53)
official/nlp/modeling/layers/kernel_attention_test.py (+4 -4)
official/nlp/modeling/layers/kernel_attention.py

@@ -41,7 +41,7 @@ class KernelMask(tf.keras.layers.Layer):
     return mask
 
 
-def pad_to_chunk_length(tensor, axis, chunk_length, pad="right"):
+def pad_to_chunk_length(tensor, axis, chunk_length, padding=None):
   """Pads a tensor so that shape[axis] is divisible by chunk_length.
 
   Args:
@@ -49,9 +49,11 @@ def pad_to_chunk_length(tensor, axis, chunk_length, pad="right"):
     axis: Axis to pad along.
     chunk_length: The output tensor will have shape[axis] divisible by
       chunk_length.
-    pad: Pad the input tensor across the axis from left if pad="left", right if
-      pad="right", or apply no padding if pad=None. In the latter case, the axis
-      dimension of the input tensor must be divisible by the chunk_length.
+    padding: Pad the input tensor across the axis from either left or
+      right if padding is set to "left" or "right"; applies no padding
+      if padding is set to None. In the latter case, the axis
+      dimension of the input tensor must be divisible by the
+      chunk_length.
 
   Returns:
     Padded tensor with shape[axis] divisible by chunk_length.
@@ -62,19 +64,23 @@ def pad_to_chunk_length(tensor, axis, chunk_length, pad="right"):
     axis += rank
   axis_length = shape[axis]
   pad_length = -axis_length % chunk_length
-  if pad == "right":
-    pad_width_2 = [[0, pad_length]]
-  elif pad == "left":
-    pad_width_2 = [[pad_length, 0]]
-  else:
+  if padding == "right":
+    axis_paddings = [[0, pad_length]]
+  elif padding == "left":
+    axis_paddings = [[pad_length, 0]]
+  elif padding is None:
     if pad_length != 0:
-      raise ValueError("When padding is not set, the axis dimension"
+      raise ValueError("When padding is None, the axis dimension"
                        "has to be divisible by the chunk_length.")
     return tensor
-  pad_width = tf.concat(
+  else:
+    raise ValueError("Illegal padding value; must be one of \"left\" "
+                     "\"right\" or None.")
+  paddings = tf.concat(
       [tf.zeros([axis, 2], dtype=tf.int32),
-       pad_width_2,
+       axis_paddings,
        tf.zeros([rank - axis - 1, 2], dtype=tf.int32)],
       axis=0)
-  return tf.pad(tensor, pad_width)
+  return tf.pad(tensor, paddings)
 
 
 def split_tensor_into_chunks(tensor, axis, chunk_length):
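Not part of the diff: a minimal sketch of the padding arithmetic that pad_to_chunk_length implements for padding="right", assuming TensorFlow 2.x; the tensor and sizes are made up for illustration.

import tensorflow as tf

tensor = tf.ones([2, 7, 4])                    # axis 1 has length 7
chunk_length = 3
pad_length = -tensor.shape[1] % chunk_length   # (-7) % 3 == 2
# Right padding appends pad_length zeros along the chosen axis, so the padded
# length (9) becomes divisible by chunk_length; padding="left" would prepend
# them instead, and padding=None would raise because 7 % 3 != 0.
padded = tf.pad(tensor, [[0, 0], [0, pad_length], [0, 0]])
print(padded.shape)                            # (2, 9, 4)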
@@ -95,12 +101,12 @@ def split_tensor_into_chunks(tensor, axis, chunk_length):
   return tf.reshape(tensor, new_shape)
 
 
-def windowed_causal_performer_attention(query_matrix,
+def causal_windowed_performer_attention(query_matrix,
                                         key_matrix,
                                         value_matrix,
                                         chunk_length,
                                         window_length,
-                                        pad="right"):
+                                        padding=None):
   """Applies windowed causal kernel attention with query, key, value tensors.
 
   We partition the T-length input sequence into N chunks, each of chunk_length
@@ -113,19 +119,19 @@ def windowed_causal_performer_attention(query_matrix,
   Below is an example with T=9, chunk_length=3, window_length=1. 1 indicates
   attention is computed between the pair while 0 indicates attention is not
   computed between the pairs:
   111000000
   111000000
   111000000
   111111000
   111111000
   111111000
   000111111
   000111111
   000111111
 
   User can ensure sequence_length is divisible by chunk_length or use
-  pad="left"/"right" to pad the sequence length either at the top or bottom
-  respectively and make it divisible by chunk_length.
+  padding="left"/"right" to pad the sequence length either at the left or right
+  respectively and make it divisible by chunk_length.
 
   Args:
     query_matrix: Kernel query `Tensor` of shape `[B, T, N, dim]`.
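Not part of the diff: a small NumPy sketch that reproduces the 9x9 pattern above, assuming query chunk i attends to key chunks i-window_length through i; all names here are illustrative only.

import numpy as np

T, chunk_length, window_length = 9, 3, 1
num_chunks = T // chunk_length
mask = np.zeros((T, T), dtype=int)
for i in range(num_chunks):                        # query chunk index
  lo = max(0, i - window_length) * chunk_length    # first visible key position
  hi = (i + 1) * chunk_length                      # one past the last visible key
  mask[i * chunk_length:(i + 1) * chunk_length, lo:hi] = 1
print("\n".join("".join(map(str, row)) for row in mask))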
@@ -133,20 +139,20 @@ def windowed_causal_performer_attention(query_matrix,
     value_matrix: Value `Tensor` of shape `[B, T, N, out_dim]`.
     chunk_length: Length of each chunk in tokens.
     window_length: Length of attention window in chunks.
-    pad: Pad the query, value and key input tensors across the T dimension from
-      left if pad="left", right if pad="right", or apply no padding if pad=None.
-      In the latter case, the T dimension of the input tensors must be divisible
-      by the chunk_length.
+    padding: Pad the query, value and key input tensors across the
+      axis from either left or right if padding is set to "left" or
+      "right"; apply no padding if padding is set to None. In the
+      latter case, the axis dimension of the query, value and key
+      input tensors must be divisible by the chunk_length.
 
   Returns:
     Window causal performer attention of shape `[B, T, N, out_dim]`.
   """
   old_shape = tf.shape(value_matrix)
 
-  query_matrix = pad_to_chunk_length(query_matrix, -3, chunk_length, pad)
-  key_matrix = pad_to_chunk_length(key_matrix, -3, chunk_length, pad)
-  value_matrix = pad_to_chunk_length(value_matrix, -3, chunk_length, pad)
+  query_matrix = pad_to_chunk_length(query_matrix, -3, chunk_length, padding)
+  key_matrix = pad_to_chunk_length(key_matrix, -3, chunk_length, padding)
+  value_matrix = pad_to_chunk_length(value_matrix, -3, chunk_length, padding)
   new_shape = tf.shape(value_matrix)
 
   chunked_query_matrix = split_tensor_into_chunks(
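Not part of the diff: a hedged sketch of calling the renamed function with the new padding keyword. It assumes the Model Garden package is importable as official.nlp.modeling.layers and that query/key are already non-negative kernel features; shapes follow the docstring.

import tensorflow as tf
from official.nlp.modeling.layers import kernel_attention

B, T, N, dim, out_dim = 2, 16, 4, 8, 8
query_prime = tf.abs(tf.random.normal([B, T, N, dim]))   # kernel features >= 0
key_prime = tf.abs(tf.random.normal([B, T, N, dim]))
value = tf.random.normal([B, T, N, out_dim])

# padding=None requires T to be divisible by chunk_length (16 % 4 == 0 here);
# pass padding="left" or padding="right" otherwise.
out = kernel_attention.causal_windowed_performer_attention(
    query_prime, key_prime, value,
    chunk_length=4, window_length=2, padding=None)
print(out.shape)   # expected (2, 16, 4, 8) per the docstring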
@@ -446,16 +452,17 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
                begin_kernel=0,
                scale=None,
                scale_by_length=False,
-               use_windowed_causal=False,
-               chunk_length=1,
-               window_length=3,
+               use_causal_windowed=False,
+               causal_chunk_length=1,
+               causal_window_length=1,
+               causal_padding=None,
                **kwargs):
     r"""Constructor of KernelAttention.
 
     Args:
-      feature_transform: A non-linear transform of the keys and quries. Possible
-        transforms are "elu", "relu", "square", "exp", "expplus", "expmod",
-        "identity".
+      feature_transform: A non-linear transform of the keys and queries.
+        Possible transforms are "elu", "relu", "square", "exp", "expplus",
+        "expmod", "identity".
       num_random_features: Number of random features to be used for projection.
         if num_random_features <= 0, no production is used before transform.
       seed: The seed to begin drawing random features. Once the seed is set, the
@@ -475,11 +482,17 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
         the dot product based on key length. Set as log_512^(n) to stablize
         attention entropy against length. Refer to
         https://kexue.fm/archives/8823 for details.
-      use_windowed_causal: If true perform windowed causal attention. See
-        windowed_causal_performer_attention function docstring for more details.
-      chunk_length: Length of each chunk in tokens.
-      window_length: Length of attention window in chunks.
-      **kwargs: The same arguments `MultiHeadAttention` layer.
+      use_causal_windowed: If true perform windowed causal attention. See
+        causal_windowed_performer_attention function docstring for more
+        details.
+      causal_chunk_length: Length of each chunk in tokens.
+      causal_window_length: Length of attention window in chunks.
+      causal_padding: Pad the query, value and key input tensors
+        across the axis from either left or right if padding is set to
+        "left" or "right"; apply no padding if padding is set to None.
+        In the latter case, the axis dimension of the query, value and
+        key input tensors must be divisible by the chunk_length.
+      **kwargs: The same arguments `MultiHeadAttention` layer.
     """
     if (feature_transform not in _TRANSFORM_MAP and
         feature_transform != "expplus"):
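Not part of the diff: a sketch of constructing the layer with the renamed causal_* arguments. num_heads and key_dim come from the inherited MultiHeadAttention signature; the concrete values are illustrative.

from official.nlp.modeling.layers import kernel_attention

layer = kernel_attention.KernelAttention(
    num_heads=8,
    key_dim=64,
    feature_transform="exp",
    num_random_features=128,
    use_causal_windowed=True,      # was use_windowed_causal
    causal_chunk_length=8,         # was chunk_length
    causal_window_length=3,        # was window_length
    causal_padding="right")        # new argument in this commit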
@@ -509,12 +522,13 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       self._projection_matrix = create_projection_matrix(
           self._num_random_features, self._key_dim,
           tf.constant([self._seed, self._seed + 1]))
-    self.use_windowed_causal = use_windowed_causal
-    self.chunk_length = chunk_length
-    self.window_length = window_length
-    if self.use_windowed_causal and self._is_short_seq:
+    self.use_causal_windowed = use_causal_windowed
+    self.causal_chunk_length = causal_chunk_length
+    self.causal_window_length = causal_window_length
+    self.causal_padding = causal_padding
+    if self.use_causal_windowed and self._is_short_seq:
       raise ValueError(
-          "use_windowed_causal and short_seq methods are mutually exclusive")
+          "use_causal_windowed and short_seq methods are mutually exclusive")
 
   def _compute_attention(self,
                          query,
@@ -590,9 +604,12 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       attention_scores = tf.einsum("BTNH,BSNH->BTSN", query_prime, key_prime)
       attention_scores = tf.nn.softmax(attention_scores, axis=2)
       attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
-    elif self.use_windowed_causal:
-      attention_output = windowed_causal_performer_attention(
-          query_prime, key_prime, value, self.chunk_length, self.window_length)
+    elif self.use_causal_windowed:
+      attention_output = causal_windowed_performer_attention(
+          query_prime, key_prime, value,
+          chunk_length=self.causal_chunk_length,
+          window_length=self.causal_window_length,
+          padding=self.causal_padding)
     else:
       kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
       denominator = 1.0 / (
...
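Not part of the diff: the default branch above is the standard softmax-free Performer normalization. The denominator expression is truncated in this excerpt, so the sketch below fills it in with the usual form (a sum over kernelized keys); that completion is an assumption, not a quote of the file.

import tensorflow as tf

B, T, S, N, H, D = 2, 8, 8, 4, 16, 32
query_prime = tf.nn.relu(tf.random.normal([B, T, N, H]))  # kernel features >= 0
key_prime = tf.nn.relu(tf.random.normal([B, S, N, H]))
value = tf.random.normal([B, S, N, D])

# Factorized attention: aggregate keys and values once, then mix per query.
kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
denominator = 1.0 / (
    tf.einsum("BTNH,BNH->BTN", query_prime, tf.reduce_sum(key_prime, axis=1))
    + 1e-6)  # small constant assumed here for numerical stability
attention_output = tf.einsum("BTNH,BNDH,BTN->BTND", query_prime, kv, denominator)
print(attention_output.shape)  # (2, 8, 4, 32)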
official/nlp/modeling/layers/kernel_attention_test.py

@@ -63,7 +63,7 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
   @parameterized.parameters(
       itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
                         [0]))
-  def test_windowed_causal_attention_projection(
+  def test_causal_windowed_attention_projection(
       self, feature_transform, num_random_features, training, redraw,
       begin_kernel):
     num_heads = 12
@@ -78,9 +78,9 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
         redraw=redraw,
         is_short_seq=False,
         begin_kernel=begin_kernel,
-        use_windowed_causal=True,
-        chunk_length=8,
-        window_length=3)
+        use_causal_windowed=True,
+        causal_chunk_length=8,
+        causal_window_length=3)
     query = tf.random.normal(
         shape=(batch_size, seq_length, key_dim))
     value = query
...
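Not part of the diff: a rough end-to-end sketch mirroring the test setup, under the assumption that KernelAttention keeps MultiHeadAttention's query/value call signature and that the unshown test constants resemble the values below.

import tensorflow as tf
from official.nlp.modeling.layers import kernel_attention

batch_size, seq_length, key_dim, num_heads = 2, 64, 64, 12
layer = kernel_attention.KernelAttention(
    num_heads=num_heads,
    key_dim=key_dim,
    feature_transform="exp",
    num_random_features=127,
    use_causal_windowed=True,
    causal_chunk_length=8,      # 64 tokens split into 8 chunks of 8
    causal_window_length=3)
query = tf.random.normal(shape=(batch_size, seq_length, key_dim))
value = query
output = layer(query=query, value=value)
print(output.shape)             # expected to match the query shape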