Commit 0d62382b, authored May 18, 2021 by Frederick Liu, committed by A. Unique TensorFlower on May 18, 2021

[efficient] Opensource kernel attention to modeling/layers.

PiperOrigin-RevId: 374472500

Parent: 09a70c7c

Showing 4 changed files with 492 additions and 0 deletions (+492, -0)
official/nlp/modeling/layers/README.md                  +9    -0
official/nlp/modeling/layers/__init__.py                 +1    -0
official/nlp/modeling/layers/kernel_attention.py         +363  -0
official/nlp/modeling/layers/kernel_attention_test.py    +119  -0
official/nlp/modeling/layers/README.md
@@ -15,6 +15,15 @@ assemble new `tf.keras` layers or models.
*   [CachedAttention](attention.py) implements an attention layer with cache
    used for auto-regressive decoding.
*   [KernelAttention](kernel_attention.py) implements a group of attention
    mechanisms that express self-attention as a linear dot-product of kernel
    feature maps and make use of the associativity property of matrix products
    to reduce the complexity from quadratic to linear. The implementation
    includes methods described in
    ["Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"](https://arxiv.org/abs/2006.16236),
    ["Rethinking Attention with Performers"](https://arxiv.org/abs/2009.14794), and
    ["Random Feature Attention"](https://openreview.net/pdf?id=QtTKTdVrFBB).
    A minimal usage sketch follows this excerpt.
*   [MatMulWithMargin](mat_mul_with_margin.py) implements a matrix
    multiplication with margin layer used for training retrieval / ranking
    tasks, as described in ["Improving Multilingual Sentence Embedding using
...
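The layer listed above is exercised end-to-end in `kernel_attention_test.py` further down; the sketch below is not part of the commit and uses illustrative shapes and hyperparameters, but it shows the intended drop-in usage next to `tf.keras.layers.MultiHeadAttention`:

import tensorflow as tf
from official.nlp.modeling import layers

# A drop-in replacement for tf.keras.layers.MultiHeadAttention that uses a
# kernel feature map ("exp", as in Performers) with 128 random features.
attn = layers.KernelAttention(
    num_heads=8,
    key_dim=64,
    feature_transform="exp",
    num_random_features=128,
    is_short_seq=False)

batch, seq_len, hidden = 2, 2048, 512
x = tf.random.normal((batch, seq_len, hidden))
mask = tf.ones((batch, seq_len))  # 1.0 for valid key positions, 0.0 for pads.

# Self-attention: query == value; the [B, S] mask is applied to the keys only.
y = attn(query=x, value=x, attention_mask=mask, training=False)
print(y.shape)  # (2, 2048, 512)

As with `MultiHeadAttention`, the output projection maps back to the last dimension of the query, so the result has the same shape as the input.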
official/nlp/modeling/layers/__init__.py
@@ -24,6 +24,7 @@ from official.nlp.modeling.layers.cls_head import *
from official.nlp.modeling.layers.dense_einsum import DenseEinsum
from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward
from official.nlp.modeling.layers.gaussian_process import RandomFeatureGaussianProcess
from official.nlp.modeling.layers.kernel_attention import KernelAttention
from official.nlp.modeling.layers.masked_lm import MaskedLM
from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
from official.nlp.modeling.layers.mat_mul_with_margin import MatMulWithMargin
...
official/nlp/modeling/layers/kernel_attention.py (new file, 0 → 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based kernel attention layer."""
import
functools
import
math
import
tensorflow
as
tf
_NUMERIC_STABLER
=
1e-6
def create_projection_matrix(m, d, seed=None):
  r"""Constructs the matrix of random projections.

  Constructs a matrix of random orthogonal projections. Each projection vector
  has a direction chosen uniformly at random and a length drawn from the
  \chi(d) distribution.

  Args:
    m: number of random projections.
    d: dimensionality of each random projection.
    seed: random seed used to construct the projections. If None, the stateful
      random API is used instead.

  Returns:
    The matrix of random projections of shape [m, d].
  """
  nb_full_blocks = math.ceil(m / d)
  block_list = tf.TensorArray(
      tf.float32, size=tf.cast(nb_full_blocks, dtype=tf.int32))
  stateful = False
  if seed is None:
    stateful = True
    # Dummy seed to make sure the graph compiles even though this path is not
    # taken.
    seed = tf.constant([0, 1])
  current_seed = seed
  for i in range(nb_full_blocks):
    if stateful:
      unstructured_block = tf.random.normal((d, d))
    else:
      unstructured_block = tf.random.stateless_normal((d, d),
                                                      seed=current_seed)
      current_seed = tf.random.stateless_uniform([2],
                                                 seed=current_seed,
                                                 minval=None,
                                                 dtype=tf.int32)
    q, _ = tf.linalg.qr(unstructured_block)
    q = tf.transpose(q)
    block_list = block_list.write(i, q)
  final_matrix = block_list.concat()[:m]
  # Scale each projection by the norm of an independent Gaussian vector so
  # that row lengths follow a chi(d) distribution.
  if stateful:
    multiplier = tf.norm(tf.random.normal((m, d)), axis=1)
  else:
    multiplier = tf.norm(
        tf.random.stateless_normal((m, d), seed=current_seed), axis=1)

  return tf.linalg.matmul(tf.linalg.diag(multiplier), final_matrix)
def _generalized_kernel(x, projection_matrix, is_query, f, h,
                        data_normalizer_fn=None):
  """Generalized kernel from "Rethinking Attention with Performers".

  Args:
    x: The feature being transformed, with shape [B, T, N, H].
    projection_matrix: The matrix of shape [M, H] that x is projected onto,
      where M is the number of projections.
    is_query: Whether the transform is applied to a query or a key. The
      transform is symmetric, so this argument is not used.
    f: A non-linear function applied to x or to the projected x.
    h: A multiplier, computed as a function of x, applied after x is projected
      and transformed. Only applied if projection_matrix is not None.
    data_normalizer_fn: A function which takes x and returns a scalar used to
      normalize the data.

  Returns:
    The transformed feature.
  """
  # No asymmetric operations.
  del is_query

  if data_normalizer_fn is not None:
    x = data_normalizer_fn(x)
  if projection_matrix is None:
    return h(x) * f(x)
  else:
    x_projected = tf.einsum("BTNH,MH->BTNM", x, projection_matrix)
    return h(x) * f(x_projected) / tf.math.sqrt(
        tf.cast(tf.shape(projection_matrix)[0], tf.float32))
# pylint: disable=g-long-lambda
_TRANSFORM_MAP = {
    "elu":
        functools.partial(
            _generalized_kernel,
            f=lambda x: tf.keras.activations.elu(x) + 1,
            h=lambda x: 1),
    "relu":
        functools.partial(
            _generalized_kernel, f=tf.keras.activations.relu, h=lambda x: 1),
    "square":
        functools.partial(
            _generalized_kernel, f=tf.math.square, h=lambda x: 1),
    "exp":
        functools.partial(
            _generalized_kernel,
            # Avoid exp explosion by shifting.
            f=lambda x: tf.math.exp(
                x - tf.math.reduce_max(x, axis=[1, 2, 3], keepdims=True)),
            h=lambda x: tf.math.exp(
                -0.5 * tf.math.reduce_sum(
                    tf.math.square(x), axis=-1, keepdims=True)),
            data_normalizer_fn=lambda x: x / (tf.math.sqrt(
                tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))))),
    "expmod":
        functools.partial(
            _generalized_kernel,
            # Avoid exp explosion by shifting.
            f=lambda x: tf.math.exp(
                x - tf.math.reduce_max(x, axis=[1, 2, 3], keepdims=True)),
            h=lambda x: tf.math.exp(
                -0.5 * tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))),
            data_normalizer_fn=lambda x: x / (tf.math.sqrt(
                tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))))),
    "l2":
        functools.partial(
            _generalized_kernel,
            f=lambda x: x,
            h=lambda x: tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32)),
            data_normalizer_fn=lambda x: x),
    "identity":
        lambda x, projection_matrix, is_query: x
}
# pylint: enable=g-long-lambda
class KernelAttention(tf.keras.layers.MultiHeadAttention):
  """A variant of efficient transformers which replaces softmax with kernels.

  This module combines ideas from the two following papers:

  Rethinking Attention with Performers
  (https://arxiv.org/abs/2009.14794)
    - exp (Lemma 1, positive), relu, l2
    - random/deterministic projection

  Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
  (https://arxiv.org/abs/2006.16236)
    - elu

  with the theory of approximating angular Performer kernels from go/performer.

  The module enables computing efficient attention in both the long-sequence
  and the shorter-sequence regime. In the former setting, the attention matrix
  is never explicitly computed; instead, its low-rank decomposition obtained
  with the given kernel feature maps is leveraged to carry out the attention
  computation (see https://arxiv.org/abs/2006.16236). In the latter setting,
  the attention matrix is constructed, but kernel features providing
  dimensionality reduction are applied, resulting in a more efficient
  computation of the attention matrix.
  """
  def __init__(self,
               feature_transform="exp",
               num_random_features=256,
               seed=0,
               redraw=False,
               is_short_seq=False,
               begin_kernel=0,
               **kwargs):
    r"""Constructor of KernelAttention.

    Args:
      feature_transform: A non-linear transform of the keys and queries.
        Possible transforms are "elu", "relu", "square", "exp", "expmod",
        "l2", "identity". If <is_short_seq> = True, it is recommended to
        choose "l2" as the feature_transform.
      num_random_features: Number of random features used for the projection.
        If num_random_features <= 0, no projection is applied before the
        transform.
      seed: The seed used to begin drawing random features. Once the seed is
        set, the pseudo-random number generation is deterministic. Users
        should pass a different seed to different layers. For multi-worker
        setups, each layer will use the same projection at each step.
      redraw: Whether to redraw the projection on every forward pass during
        training. The argument is only effective when num_random_features > 0.
      is_short_seq: boolean predicate indicating whether the input data
        consists of very short sequences or not; in most cases this should be
        False (the default).
      begin_kernel: Apply kernel attention from this sequence position onward
        and softmax attention before it.
      **kwargs: The same arguments as the `MultiHeadAttention` layer.
    """
    if feature_transform not in _TRANSFORM_MAP:
      raise ValueError("Unsupported feature_transform. The supported "
                       "feature_transforms are %s. "
                       "Got '%s'." % (_TRANSFORM_MAP.keys(), feature_transform))
    if num_random_features <= 0 and redraw:
      raise ValueError(
          "There is nothing to redraw when num_random_features <= 0.")
    self._feature_transform = feature_transform
    self._num_random_features = num_random_features
    self._redraw = redraw
    self._is_short_seq = is_short_seq
    self._begin_kernel = begin_kernel
    # We use the seed for two scenarios:
    # 1. inference
    # 2. no redraw
    self._seed = seed
    super().__init__(**kwargs)
    self._projection_matrix = None
    if num_random_features > 0:
      self._projection_matrix = create_projection_matrix(
          self._num_random_features, self._key_dim,
          tf.constant([self._seed, self._seed + 1]))
  def _compute_attention(self,
                         query,
                         key,
                         value,
                         feature_transform,
                         is_short_seq,
                         attention_mask=None,
                         training=False,
                         numeric_stabler=_NUMERIC_STABLER):
    """Applies kernel attention with query, key, value tensors.

    This function defines the computation inside `call` with projected
    multi-head Q, K, V inputs. Users can override this function for customized
    attention implementations.

    Args:
      query: Projected query `Tensor` of shape `[B, T, N, key_dim]`.
      key: Projected key `Tensor` of shape `[B, S, N, key_dim]`.
      value: Projected value `Tensor` of shape `[B, S, N, value_dim]`.
      feature_transform: A non-linear transform of the keys and queries.
      is_short_seq: boolean predicate indicating whether the input data
        consists of short or long sequences; usually a short sequence is
        defined as having length L <= 1024.
      attention_mask: a boolean mask of shape `[B, S]` that prevents attending
        to certain positions. Note that the mask is only applied to the keys.
        Users may want to mask the output if the query contains pads.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).
      numeric_stabler: A scalar value added to avoid division by zero.

    Returns:
      attention_output: Multi-headed outputs of attention computation.
    """
    projection_matrix = None
    if self._num_random_features > 0:
      if self._redraw and training:
        projection_matrix = create_projection_matrix(
            self._num_random_features, self._key_dim)
      else:
        projection_matrix = self._projection_matrix
    key = _TRANSFORM_MAP[feature_transform](key, projection_matrix, False)
    query = _TRANSFORM_MAP[feature_transform](query, projection_matrix, True)
    if attention_mask is not None:
      key = tf.einsum("BSNH,BS->BSNH", key, attention_mask)

    if is_short_seq:
      attention_scores = tf.einsum("BTNH,BSNH->BTSN", query, key)
      attention_scores = tf.nn.softmax(attention_scores, axis=2)
      attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
      return attention_output
    else:
      kv = tf.einsum("BSNH,BSND->BNDH", key, value)
      denominator = 1.0 / (
          tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) +
          numeric_stabler)
      return tf.einsum("BTNH,BNDH,BTN->BTND", query, kv, denominator)
  def _build_from_signature(self, query, value, key=None):
    super()._build_from_signature(query=query, value=value, key=key)
    if self._begin_kernel > 0:
      common_kwargs = dict(
          kernel_initializer=self._kernel_initializer,
          bias_initializer=self._bias_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer,
          activity_regularizer=self._activity_regularizer,
          kernel_constraint=self._kernel_constraint,
          bias_constraint=self._bias_constraint)
      self._output_dense_softmax = self._make_output_dense(
          self._query_shape.rank - 1, common_kwargs,
          name="attention_output_softmax")
      self._dropout_softmax = tf.keras.layers.Dropout(rate=self._dropout)
  def call(self,
           query,
           value,
           key=None,
           attention_mask=None,
           training=False,
           **kwargs):
    if not self._built_from_signature:
      self._build_from_signature(query=query, value=value, key=key)
    if key is None:
      key = value

    # N = `num_attention_heads`
    # H = `size_per_head`
    # `query` = [B, T, N, H]
    query = self._query_dense(query)

    # `key` = [B, S, N, H]
    key = self._key_dense(key)

    # `value` = [B, S, N, D]
    value = self._value_dense(value)

    if self._begin_kernel > 0:
      attention_output_softmax = self._compute_attention(
          query[:, :self._begin_kernel], key, value, "identity", True,
          attention_mask, training)
      attention_output_softmax = self._dropout_softmax(
          attention_output_softmax)
      attention_output_softmax = self._output_dense_softmax(
          attention_output_softmax)
      attention_output_kernel = self._compute_attention(
          query[:, self._begin_kernel:], key, value, self._feature_transform,
          self._is_short_seq, attention_mask, training)
      attention_output_kernel = self._dropout_layer(attention_output_kernel)
      attention_output_kernel = self._output_dense(attention_output_kernel)
      attention_output = tf.concat(
          [attention_output_softmax, attention_output_kernel], axis=1)
    else:
      attention_output = self._compute_attention(query, key, value,
                                                 self._feature_transform,
                                                 self._is_short_seq,
                                                 attention_mask, training)
      # This is actually dropping out entire tokens to attend to, which might
      # seem a bit unusual, but is taken from the original Transformer paper.
      attention_output = self._dropout_layer(attention_output)
      attention_output = self._output_dense(attention_output)
    return attention_output
  def get_config(self):
    config = {
        "feature_transform": self._feature_transform,
        "num_random_features": self._num_random_features,
        "seed": self._seed,
        "redraw": self._redraw,
        "is_short_seq": self._is_short_seq,
        "begin_kernel": self._begin_kernel,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))
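The class docstring states that in the long-sequence regime the attention matrix is never materialized. The sketch below is not part of the commit (single head, `elu(x) + 1` feature map as in "Transformers are RNNs"); it illustrates the associativity identity exploited by the long-sequence branch of `_compute_attention`. The short-sequence branch differs in that it applies a softmax to the materialized scores, and the layer also adds `numeric_stabler` to the denominator, omitted here because the feature map is strictly positive.

import tensorflow as tf

T, S, H, D = 8, 16, 4, 4            # query length, key length, key dim, value dim
phi = lambda x: tf.nn.elu(x) + 1.0  # positive feature map (the "elu" transform)
q = phi(tf.random.normal((T, H)))
k = phi(tf.random.normal((S, H)))
v = tf.random.normal((S, D))

# Quadratic form: materialize the [T, S] score matrix, normalize its rows,
# then apply it to the values.
scores = tf.matmul(q, k, transpose_b=True)                        # [T, S]
weights = scores / tf.reduce_sum(scores, axis=-1, keepdims=True)
out_quadratic = tf.matmul(weights, v)                             # [T, D]

# Linear form: contract phi(K)^T V first, so the [T, S] matrix never exists
# and the cost is linear in the sequence length.
kv = tf.matmul(k, v, transpose_a=True)                            # [H, D]
denom = tf.matmul(q, tf.reduce_sum(k, axis=0, keepdims=True),
                  transpose_b=True)                               # [T, 1]
out_linear = tf.matmul(q, kv) / denom                             # [T, D]

tf.debugging.assert_near(out_quadratic, out_linear, atol=1e-4)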
official/nlp/modeling/layers/kernel_attention_test.py (new file, 0 → 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.projects.kernel.attention."""
import
itertools
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.nlp.modeling.layers
import
kernel_attention
as
attention
_FEATURE_TRANSFORM
=
[
'relu'
,
'elu'
,
'exp'
,
'l2'
]
_REDRAW
=
[
True
,
False
]
_TRAINING
=
[
True
,
False
]
_IS_SHORT_SEQ
=
[
True
,
False
]
_BEGIN_KERNEL
=
[
0
,
512
]
class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(
      itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
                        _IS_SHORT_SEQ, _BEGIN_KERNEL))
  def test_attention_projection(self, feature_transform, num_random_features,
                                training, redraw, is_short, begin_kernel):
    num_heads = 12
    key_dim = 64
    seq_length = 1024
    batch_size = 2
    test_layer = attention.KernelAttention(
        num_heads=num_heads,
        key_dim=key_dim,
        feature_transform=feature_transform,
        num_random_features=num_random_features,
        redraw=redraw,
        is_short_seq=is_short,
        begin_kernel=begin_kernel)
    query = tf.random.normal(shape=(batch_size, seq_length, key_dim))
    value = query
    encoder_inputs_mask = tf.zeros((batch_size, seq_length), dtype=tf.int32)
    masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
    output = test_layer(
        query=query, value=value, attention_mask=masks, training=training)
    self.assertEqual(output.shape, [batch_size, seq_length, key_dim])
  @parameterized.parameters(
      itertools.product(_FEATURE_TRANSFORM, [0], _TRAINING, [False],
                        _IS_SHORT_SEQ, _BEGIN_KERNEL))
  def test_attention_no_projection(self, feature_transform,
                                   num_random_features, training, redraw,
                                   is_short, begin_kernel):
    num_heads = 12
    key_dim = 64
    seq_length = 1024
    batch_size = 2
    test_layer = attention.KernelAttention(
        num_heads=num_heads,
        key_dim=key_dim,
        feature_transform=feature_transform,
        num_random_features=num_random_features,
        redraw=redraw,
        is_short_seq=is_short,
        begin_kernel=begin_kernel)
    query = tf.random.normal(shape=(batch_size, seq_length, key_dim))
    value = query
    encoder_inputs_mask = tf.zeros((batch_size, seq_length), dtype=tf.int32)
    masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
    output = test_layer(
        query=query, value=value, attention_mask=masks, training=training)
    self.assertEqual(output.shape, [batch_size, seq_length, key_dim])
  def test_unsupported_feature_transform(self):
    with self.assertRaisesRegex(ValueError, 'Unsupported feature_transform.*'):
      _ = attention.KernelAttention(feature_transform='test')

  def test_redraw_true_no_projection(self):
    with self.assertRaisesRegex(
        ValueError, 'There is nothing to redraw when num_random_features.*'):
      _ = attention.KernelAttention(
          num_heads=2,
          key_dim=64,
          feature_transform='elu',
          num_random_features=0,
          redraw=True)
  def test_config(self):
    num_heads = 12
    key_dim = 64
    test_layer = attention.KernelAttention(
        num_heads=num_heads,
        key_dim=key_dim,
        feature_transform='exp',
        num_random_features=128,
        is_short_seq=True)
    new_layer = attention.KernelAttention.from_config(test_layer.get_config())
    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(test_layer.get_config(), new_layer.get_config())


if __name__ == '__main__':
  tf.test.main()
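As a complement to the shape-level tests above, the short check below is not part of the commit (it assumes it is run from a source tree containing `official/`); it probes the two properties documented for `create_projection_matrix`: rows within one d x d block are mutually orthogonal, and row lengths follow a chi(d) distribution, so the mean squared row norm is close to d.

import tensorflow as tf
from official.nlp.modeling.layers import kernel_attention

m, d = 128, 64
proj = kernel_attention.create_projection_matrix(m, d, seed=tf.constant([0, 1]))
print(proj.shape)  # (128, 64)

# Rows of the first d x d block are orthogonal: after normalizing their
# lengths, the Gram matrix should be the identity.
block = proj[:d]
unit = block / tf.norm(block, axis=-1, keepdims=True)
gram = tf.matmul(unit, unit, transpose_b=True)
tf.debugging.assert_near(gram, tf.eye(d), atol=1e-4)

# Row lengths follow chi(d), whose second moment is d, so the mean squared
# row norm should be close to d = 64.
print(float(tf.reduce_mean(tf.reduce_sum(tf.square(proj), axis=-1))))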