ModelZoo / ResNet50_tensorflow

Commit d51cc280, authored Apr 12, 2022 by Jialu Liu, committed by A. Unique TensorFlower on Apr 12, 2022

    Internal change

    PiperOrigin-RevId: 441337748

Parent: 54659689

Showing 2 changed files with 58 additions and 20 deletions:

    official/nlp/modeling/layers/kernel_attention.py        +29  -17
    official/nlp/modeling/layers/kernel_attention_test.py   +29   -3
official/nlp/modeling/layers/kernel_attention.py

@@ -178,13 +178,13 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
                is_short_seq=False,
                begin_kernel=0,
                scale=None,
+               scale_by_length=False,
                **kwargs):
     r"""Constructor of KernelAttention.

     Args:
-      feature_transform: A non-linear transform of the keys and quries.
-        Possible transforms are "elu", "relu", "square", "exp", "expmod",
-        "identity".
+      feature_transform: A non-linear transform of the keys and quries. Possible
+        transforms are "elu", "relu", "square", "exp", "expmod", "identity".
       num_random_features: Number of random features to be used for projection.
         if num_random_features <= 0, no production is used before transform.
       seed: The seed to begin drawing random features. Once the seed is set, the
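Note (not part of the diff): a minimal usage sketch of the new `scale_by_length` constructor argument added above. The import alias follows the test file further down; the head count, key dimension, and input shapes are illustrative assumptions.

# Sketch only: constructing KernelAttention with the new flag.
import tensorflow as tf
from official.nlp.modeling.layers import kernel_attention as attention  # assumed import path

layer = attention.KernelAttention(
    num_heads=8,
    key_dim=64,
    num_random_features=0,       # no random projection, as in the new test below
    scale_by_length=True)        # new argument introduced by this commit

query = tf.random.normal((2, 128, 64))        # [B, S, dim]
mask = tf.ones((2, 128), dtype=tf.float32)    # [B, S]
output = layer(query=query, value=query, attention_mask=mask)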
@@ -194,12 +194,16 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       redraw: Whether to redraw projection every forward pass during training.
         The argument is only effective when num_random_features > 0.
       is_short_seq: boolean predicate indicating whether input data consists of
         very short sequences or not; in most cases this should be False (default
         option).
       begin_kernel: Apply kernel_attention after this sequence id and apply
         softmax attention before this.
       scale: The value to scale the dot product as described in `Attention Is
         All You Need`. If None, we use 1/sqrt(dk) as described in the paper.
+      scale_by_length: boolean predicate indicating whether additionally scale
+        the dot product based on key length. Set as log_512^(n) to stablize
+        attention entropy against length. Refer to
+        https://kexue.fm/archives/8823 for details.
       **kwargs: The same arguments `MultiHeadAttention` layer.
     """
     if feature_transform not in _TRANSFORM_MAP:
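Note (not part of the diff): concretely, with n unmasked key positions and head size dk, the docstring above implies an effective dot-product scale of (1/sqrt(dk)) * log(n)/log(512). A small sketch of this factor; the helper name is hypothetical.

# Sketch only: effective scale implied by the scale_by_length docstring.
import math

def effective_scale(n, key_dim):
  """1/sqrt(dk) multiplied by log base 512 of the key length n (assumed helper)."""
  return (1.0 / math.sqrt(key_dim)) * math.log(n) / math.log(512)

# log_512(512) == 1, so a length-512 sequence keeps the standard 1/sqrt(dk) scale.
assert abs(effective_scale(512, 64) - 1.0 / math.sqrt(64)) < 1e-9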
@@ -214,6 +218,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     self._redraw = redraw
     self._is_short_seq = is_short_seq
     self._begin_kernel = begin_kernel
+    self._scale_by_length = scale_by_length
     # We use the seed for two scenarios:
     # 1. inference
     # 2. no redraw
@@ -252,9 +257,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       is_short_seq: boolean predicate indicating whether input data consists of
         short or long sequences; usually short sequence is defined as having
         length L <= 1024.
       attention_mask: a boolean mask of shape `[B, S]`, that prevents attenting
         to masked positions. Note that the mask is only appied to the keys. User
         may want to mask the output if query contains pads.
       training: Python boolean indicating whether the layer should behave in
         training mode (adding dropout) or in inference mode (doing nothing).
       numeric_stabler: A scalar value added to avoid divide by 0.
@@ -270,17 +275,23 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     else:
       projection_matrix = self._projection_matrix
+    if self._scale_by_length:
+      scale = tf.math.log(tf.reduce_sum(
+          attention_mask, axis=-1)) * self._scale / math.log(512)
+      scale = tf.reshape(scale, [-1, 1, 1, 1])
+    else:
+      scale = self._scale
     if is_short_seq:
       # Note: Applying scalar multiply at the smaller end of einsum improves
       # XLA performance, but may introduce slight numeric differences in
       # the Transformer attention head.
-      query = query * self._scale
+      query = query * scale
     else:
       # Note: we suspect spliting the scale to key, query yields smaller
       # approximation variance when random projection is used.
       # For simplicity, we also split when there's no random projection.
-      key *= math.sqrt(self._scale)
-      query *= math.sqrt(self._scale)
+      key *= tf.math.sqrt(scale)
+      query *= tf.math.sqrt(scale)
     key = _TRANSFORM_MAP[feature_transform](key, projection_matrix)
     query = _TRANSFORM_MAP[feature_transform](query, projection_matrix)
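Note (not part of the diff): a standalone sketch of the per-example scale computed in the hunk above. Summing the `[B, S]` attention mask over the key axis gives each example's effective length n; the log_512(n) factor is applied to the base scale, then reshaped so it broadcasts per example over the per-head query/key tensors. The mask values and key_dim below are illustrative.

# Sketch only: per-example scale, mirroring the scale_by_length branch above.
import math
import tensorflow as tf

base_scale = 1.0 / math.sqrt(64)  # stands in for self._scale (1/sqrt(dk))

# Batch of 2: first example fully unmasked (length 512), second padded to length 128.
attention_mask = tf.stack([
    tf.ones(512),
    tf.concat([tf.ones(128), tf.zeros(384)], axis=0),
])

lengths = tf.reduce_sum(attention_mask, axis=-1)            # [512., 128.]
scale = tf.math.log(lengths) * base_scale / math.log(512)   # log_512(n) / sqrt(dk)
scale = tf.reshape(scale, [-1, 1, 1, 1])                     # broadcastable per example
print(scale.shape)  # (2, 1, 1, 1)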
@@ -330,9 +341,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       value: Value `Tensor` of shape `[B, S, dim]`.
       key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
         `value` for both `key` and `value`, which is the most common case.
       attention_mask: a boolean mask of shape `[B, S]`, that prevents attenting
         to masked positions. Note that the mask is only appied to the keys. User
         may want to mask the output if query contains pads.
       training: Python boolean indicating whether the layer should behave in
         training mode (adding dropout) or in inference mode (doing nothing).
@@ -373,9 +384,10 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       attention_output = tf.concat(
           [attention_output_softmax, attention_output_kernel], axis=1)
     else:
       attention_output = self._compute_attention(query, key, value,
                                                  self._feature_transform,
                                                  self._is_short_seq,
                                                  attention_mask, training)
     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
     attention_output = self._dropout_layer(attention_output)
official/nlp/modeling/layers/kernel_attention_test.py

@@ -30,9 +30,9 @@ _BEGIN_KERNEL = [0, 512]
 class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):

-  @parameterized.parameters(itertools.product(
-      _FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
-      _IS_SHORT_SEQ, _BEGIN_KERNEL))
+  @parameterized.parameters(
+      itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
+                        _IS_SHORT_SEQ, _BEGIN_KERNEL))
   def test_attention_projection(
       self, feature_transform, num_random_features, training, redraw, is_short,
       begin_kernel):
@@ -90,6 +90,32 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
         training=training)
     self.assertEqual(output.shape, [batch_size, seq_length, key_dim])

+  @parameterized.parameters([128, 512])
+  def test_attention_scale_by_length(self, seq_length):
+    num_heads = 12
+    key_dim = 64
+    batch_size = 2
+    test_layer = attention.KernelAttention(
+        num_heads=num_heads,
+        key_dim=key_dim,
+        num_random_features=0,
+        scale_by_length=True)
+    query = tf.random.normal(shape=(batch_size, seq_length, key_dim))
+    value = query
+    encoder_inputs_mask = tf.ones((batch_size, seq_length), dtype=tf.int32)
+    masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
+    output_scale_by_length = test_layer(
+        query=query, value=value, attention_mask=masks)
+    test_layer._scale_by_length = False
+    output_no_scale_by_length = test_layer(
+        query=query, value=value, attention_mask=masks)
+    if seq_length == 512:
+      # Equals because log(seq_length, base=512) = 1.0
+      self.assertAllClose(output_scale_by_length, output_no_scale_by_length)
+    else:
+      self.assertNotAllClose(output_scale_by_length, output_no_scale_by_length)
+
   def test_unsupported_feature_transform(self):
     with self.assertRaisesRegex(ValueError, 'Unsupported feature_transform.*'):
       _ = attention.KernelAttention(feature_transform='test')
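Note (not part of the diff): the two parameterizations of `test_attention_scale_by_length` above hinge on the base-512 logarithm. At sequence length 512 the extra factor is exactly 1, so scaled and unscaled outputs match; at 128 it is log(128)/log(512) = 7/9, so they differ.

# Sketch only: the arithmetic behind the assertAllClose / assertNotAllClose branch.
import math
print(math.log(512) / math.log(512))  # 1.0
print(math.log(128) / math.log(512))  # 0.777...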