ModelZoo / ResNet50_tensorflow / Commits / d51cc280
"sgl-kernel/python/vscode:/vscode.git/clone" did not exist on "3db35c1af4a91479da526d8d3b1adadcfdeba054"
Commit d51cc280, authored Apr 12, 2022 by Jialu Liu, committed by A. Unique TensorFlower on Apr 12, 2022

Internal change

PiperOrigin-RevId: 441337748
Parent: 54659689
Changes: 2 changed files, with 58 additions and 20 deletions (+58, -20)

  official/nlp/modeling/layers/kernel_attention.py        +29 -17
  official/nlp/modeling/layers/kernel_attention_test.py   +29 -3
official/nlp/modeling/layers/kernel_attention.py
@@ -178,13 +178,13 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
                is_short_seq=False,
                begin_kernel=0,
                scale=None,
+               scale_by_length=False,
                **kwargs):
     r"""Constructor of KernelAttention.

     Args:
-      feature_transform: A non-linear transform of the keys and queries.
-        Possible transforms are "elu", "relu", "square", "exp", "expmod",
-        "identity".
+      feature_transform: A non-linear transform of the keys and queries. Possible
+        transforms are "elu", "relu", "square", "exp", "expmod", "identity".
       num_random_features: Number of random features to be used for projection.
         If num_random_features <= 0, no projection is used before transform.
       seed: The seed to begin drawing random features. Once the seed is set, the
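The feature_transform choices named above select the non-linearity applied to keys and queries after the optional random projection. The actual _TRANSFORM_MAP contents are outside this diff; the following is a hypothetical miniature version, just to make the dispatch pattern concrete (the dict name and entries are illustrative assumptions, not the library's code):

    import tensorflow as tf

    # Illustrative only; the real map lives elsewhere in kernel_attention.py.
    _TOY_TRANSFORM_MAP = {
        'relu': lambda x, proj: tf.nn.relu(x),
        'elu': lambda x, proj: tf.nn.elu(x) + 1.0,  # shifted to stay positive
        'identity': lambda x, proj: x,
    }
    query = tf.random.normal((2, 4, 8))
    transformed = _TOY_TRANSFORM_MAP['elu'](query, None)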
@@ -194,12 +194,16 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       redraw: Whether to redraw projection every forward pass during training.
         The argument is only effective when num_random_features > 0.
       is_short_seq: boolean predicate indicating whether input data consists of
         very short sequences or not; in most cases this should be False (the
         default option).
       begin_kernel: Apply kernel_attention after this sequence id and apply
         softmax attention before this.
       scale: The value to scale the dot product as described in `Attention Is
         All You Need`. If None, we use 1/sqrt(dk) as described in the paper.
+      scale_by_length: boolean predicate indicating whether to additionally
+        scale the dot product based on key length. Set to log_512(n) to
+        stabilize attention entropy against length. Refer to
+        https://kexue.fm/archives/8823 for details.
       **kwargs: The same arguments as the `MultiHeadAttention` layer.
     """
     if feature_transform not in _TRANSFORM_MAP:
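A note on the new docstring: "log_512(n)" is the base-512 logarithm of the key length n. A minimal standalone sketch of the intended scale, assuming the conventional 1/sqrt(d_k) base factor; the helper name is illustrative, not part of this diff:

    import math

    def length_scaled_attention_scale(key_length, key_dim):
      # Vanilla scale from `Attention Is All You Need`: 1/sqrt(d_k).
      base_scale = 1.0 / math.sqrt(key_dim)
      # scale_by_length multiplies it by log_512(n) = log(n)/log(512).
      return math.log(key_length) / math.log(512) * base_scale

    # At n == 512 the extra factor is exactly 1.0, so nothing changes:
    assert length_scaled_attention_scale(512, 64) == 1.0 / math.sqrt(64)
    # At n == 128 the factor is log(128)/log(512) = 7/9, about 0.778:
    print(length_scaled_attention_scale(128, 64))  # ~0.0972 vs 0.125 unscaled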
@@ -214,6 +218,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     self._redraw = redraw
     self._is_short_seq = is_short_seq
     self._begin_kernel = begin_kernel
+    self._scale_by_length = scale_by_length
     # We use the seed for two scenarios:
     # 1. inference
     # 2. no redraw
@@ -252,9 +257,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       is_short_seq: boolean predicate indicating whether input data consists of
         short or long sequences; usually a short sequence is defined as having
         length L <= 1024.
       attention_mask: a boolean mask of shape `[B, S]` that prevents attending
         to masked positions. Note that the mask is only applied to the keys.
         The user may want to mask the output if the query contains pads.
       training: Python boolean indicating whether the layer should behave in
         training mode (adding dropout) or in inference mode (doing nothing).
       numeric_stabler: A scalar value added to avoid divide by 0.
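A side note on the attention_mask semantics documented above: because the mask is applied only to keys, padded query positions still produce outputs, and the caller is expected to zero them afterwards. A hedged sketch of that post-masking step (tensor names and shapes are illustrative, not part of this diff):

    import tensorflow as tf

    output = tf.random.normal((2, 4, 8))          # layer output [B, S, dim], pads included
    query_mask = tf.constant([[1., 1., 1., 0.],   # 1 = real token, 0 = pad
                              [1., 1., 0., 0.]])
    # Zero out outputs at padded query positions, as the docstring suggests.
    masked_output = output * query_mask[:, :, tf.newaxis]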
@@ -270,17 +275,23 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     else:
       projection_matrix = self._projection_matrix
+    if self._scale_by_length:
+      scale = tf.math.log(tf.reduce_sum(attention_mask, axis=-1)) * self._scale / math.log(512)
+      scale = tf.reshape(scale, [-1, 1, 1, 1])
+    else:
+      scale = self._scale
     if is_short_seq:
       # Note: Applying scalar multiply at the smaller end of einsum improves
       # XLA performance, but may introduce slight numeric differences in
       # the Transformer attention head.
-      query = query * self._scale
+      query = query * scale
     else:
       # Note: we suspect splitting the scale to key, query yields smaller
       # approximation variance when random projection is used.
       # For simplicity, we also split when there's no random projection.
-      key *= math.sqrt(self._scale)
-      query *= math.sqrt(self._scale)
+      key *= tf.math.sqrt(scale)
+      query *= tf.math.sqrt(scale)
     key = _TRANSFORM_MAP[feature_transform](key, projection_matrix)
     query = _TRANSFORM_MAP[feature_transform](query, projection_matrix)
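The new branch above turns `scale` from a scalar into a per-example tensor: `tf.reduce_sum(attention_mask, axis=-1)` recovers each example's true key length, and the reshape to `[-1, 1, 1, 1]` makes it broadcast across the head and position axes. A standalone sketch of just that computation, assuming a float mask of shape `[B, S]` as in the new test below (the base_scale variable stands in for `self._scale`):

    import math
    import tensorflow as tf

    attention_mask = tf.constant([[1., 1., 1., 0.],   # example 0: key length 3
                                  [1., 1., 1., 1.]])  # example 1: key length 4
    base_scale = 1.0 / math.sqrt(64)                  # stands in for self._scale

    # Same math as the added lines in the diff.
    scale = tf.math.log(tf.reduce_sum(attention_mask, axis=-1)) * base_scale / math.log(512)
    scale = tf.reshape(scale, [-1, 1, 1, 1])
    print(scale.shape)  # (2, 1, 1, 1), broadcastable against per-head tensors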
@@ -330,9 +341,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       value: Value `Tensor` of shape `[B, S, dim]`.
       key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
         `value` for both `key` and `value`, which is the most common case.
       attention_mask: a boolean mask of shape `[B, S]` that prevents attending
         to masked positions. Note that the mask is only applied to the keys.
         The user may want to mask the output if the query contains pads.
       training: Python boolean indicating whether the layer should behave in
         training mode (adding dropout) or in inference mode (doing nothing).
@@ -373,9 +384,10 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       attention_output = tf.concat(
           [attention_output_softmax, attention_output_kernel], axis=1)
     else:
       attention_output = self._compute_attention(query, key, value,
                                                  self._feature_transform,
                                                  self._is_short_seq,
                                                  attention_mask, training)
     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
     attention_output = self._dropout_layer(attention_output)
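For context on the surrounding (unchanged) code: when `begin_kernel > 0`, the layer runs softmax attention on the first `begin_kernel` query positions and kernel attention on the rest, and the `tf.concat([...], axis=1)` line stitches the two outputs back together along the sequence axis. A shape-level sketch, with placeholder tensors that are illustrative only:

    import tensorflow as tf

    begin_kernel, seq_length, num_heads, head_dim = 2, 6, 4, 16
    # Placeholder outputs of the two attention regimes, shaped [B, T, N, H].
    attention_output_softmax = tf.zeros((1, begin_kernel, num_heads, head_dim))
    attention_output_kernel = tf.zeros((1, seq_length - begin_kernel, num_heads, head_dim))
    # Concatenating on axis=1 (the sequence axis) restores the full length.
    out = tf.concat([attention_output_softmax, attention_output_kernel], axis=1)
    assert out.shape == (1, seq_length, num_heads, head_dim)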
official/nlp/modeling/layers/kernel_attention_test.py
@@ -30,8 +30,8 @@ _BEGIN_KERNEL = [0, 512]
 class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):

   @parameterized.parameters(
       itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
                         _IS_SHORT_SEQ, _BEGIN_KERNEL))
   def test_attention_projection(self, feature_transform, num_random_features,
                                 training, redraw, is_short,
@@ -90,6 +90,32 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
         training=training)
     self.assertEqual(output.shape, [batch_size, seq_length, key_dim])

+  @parameterized.parameters([128, 512])
+  def test_attention_scale_by_length(self, seq_length):
+    num_heads = 12
+    key_dim = 64
+    batch_size = 2
+    test_layer = attention.KernelAttention(
+        num_heads=num_heads,
+        key_dim=key_dim,
+        num_random_features=0,
+        scale_by_length=True)
+    query = tf.random.normal(shape=(batch_size, seq_length, key_dim))
+    value = query
+    encoder_inputs_mask = tf.ones((batch_size, seq_length), dtype=tf.int32)
+    masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
+    output_scale_by_length = test_layer(
+        query=query, value=value, attention_mask=masks)
+    test_layer._scale_by_length = False
+    output_no_scale_by_length = test_layer(
+        query=query, value=value, attention_mask=masks)
+    if seq_length == 512:
+      # Outputs match because log_512(512) = 1.0 leaves the scale unchanged.
+      self.assertAllClose(output_scale_by_length, output_no_scale_by_length)
+    else:
+      self.assertNotAllClose(output_scale_by_length,
+                             output_no_scale_by_length)

   def test_unsupported_feature_transform(self):
     with self.assertRaisesRegex(ValueError, 'Unsupported feature_transform.*'):
       _ = attention.KernelAttention(feature_transform='test')
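Why the new test branches on seq_length == 512: with scale_by_length enabled, the dot-product scale is multiplied by log_512(n), which is exactly 1.0 at n = 512 (so the outputs match) and about 0.778 at n = 128 (so they diverge). A quick standalone check of those factors, independent of the test:

    import math

    def log_512(n):
      return math.log(n) / math.log(512)

    assert log_512(512) == 1.0                 # scaled == unscaled at length 512
    assert abs(log_512(128) - 7 / 9) < 1e-12   # log(2**7)/log(2**9) = 7/9, ~0.778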