ModelZoo / ResNet50_tensorflow · Commits

Commit b3483b39
Authored Jul 02, 2021 by Frederick Liu
Committed by A. Unique TensorFlower, Jul 02, 2021

Internal change

PiperOrigin-RevId: 382846192
Parent: affbaa60
Changes: 3 changed files, 40 additions and 41 deletions (+40, -41)

  official/nlp/configs/encoders.py                       +2   -0
  official/nlp/modeling/layers/kernel_attention.py       +37  -40
  official/nlp/modeling/layers/kernel_attention_test.py  +1   -1
official/nlp/configs/encoders.py

@@ -161,6 +161,7 @@ class KernelEncoderConfig(hyperparams.Config):
   redraw: bool = False
   is_short_seq: bool = False
   begin_kernel: int = 0
+  scale: Optional[float] = None
 
 
 @dataclasses.dataclass
@@ -377,6 +378,7 @@ def build_encoder(config: EncoderConfig,
         redraw=encoder_cfg.redraw,
         is_short_seq=encoder_cfg.is_short_seq,
         begin_kernel=encoder_cfg.begin_kernel,
+        scale=encoder_cfg.scale,
     )
     hidden_cfg = dict(
         num_attention_heads=encoder_cfg.num_attention_heads,
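The new `scale` field defaults to None so that existing configs keep the 1/sqrt(d_k) behaviour. A minimal standalone sketch of that fallback logic, using plain dataclasses rather than the actual Model Garden classes (`KernelEncoderConfigSketch` and `resolve_scale` are hypothetical names mirroring the diff):

import dataclasses
import math
from typing import Optional


@dataclasses.dataclass
class KernelEncoderConfigSketch:
  # Hypothetical stand-in for KernelEncoderConfig; field names mirror the diff.
  redraw: bool = False
  is_short_seq: bool = False
  begin_kernel: int = 0
  scale: Optional[float] = None


def resolve_scale(cfg: KernelEncoderConfigSketch, key_dim: int) -> float:
  # Mirrors the constructor logic added in kernel_attention.py: None -> 1/sqrt(d_k).
  return 1.0 / math.sqrt(float(key_dim)) if cfg.scale is None else cfg.scale


print(resolve_scale(KernelEncoderConfigSketch(), key_dim=64))          # 0.125
print(resolve_scale(KernelEncoderConfigSketch(scale=0.2), key_dim=64))  # 0.2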
official/nlp/modeling/layers/kernel_attention.py

@@ -85,30 +85,20 @@ def create_projection_matrix(m, d, seed=None):
   return tf.linalg.matmul(tf.linalg.diag(multiplier), final_matrix)
 
 
-def _generalized_kernel(x, projection_matrix, is_query, f, h,
-                        data_normalizer_fn=None):
+def _generalized_kernel(x, projection_matrix, f, h):
   """Generalized kernel in RETHINKING ATTENTION WITH PERFORMERS.
 
   Args:
     x: The feature being transformed with shape [B, T, N ,H].
     projection_matrix: The matrix with shape [M, H] that we projecct x to, where
       M is the number of projections.
-    is_query: Whether the transform is a query or key. This transform is
-      symmetric is the argument is not used.
     f: A non-linear function applied on x or projected x.
     h: A muliplier which is a function of x applied after projected and
       transformed. Only applied if projection_matrix is not None.
-    data_normalizer_fn: A function which takes x and returns a scalar that
-      normalize data.
 
   Returns:
     Transformed feature.
   """
-  # No asymmetric operations.
-  del is_query
-  if data_normalizer_fn is not None:
-    x = data_normalizer_fn(x)
   if projection_matrix is None:
     return h(x) * f(x)
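With `is_query` and `data_normalizer_fn` removed, the transform is symmetric between key and query and reduces to h(x) * f(x P^T). A rough standalone illustration of that feature map on random inputs (the helper name and shapes are mine, not the library function):

import tensorflow as tf


def generalized_kernel_sketch(x, projection_matrix, f, h):
  # x: [B, T, N, H]; projection_matrix: [M, H] or None.
  if projection_matrix is None:
    return h(x) * f(x)
  x_projected = tf.einsum("BTNH,MH->BTNM", x, projection_matrix)
  return h(x) * f(x_projected)


x = tf.random.normal([2, 8, 4, 16])   # [batch, seq, heads, head_dim]
proj = tf.random.normal([32, 16])     # [num_random_features, head_dim]
phi = generalized_kernel_sketch(x, proj, f=tf.nn.relu, h=lambda x: 1.0)
print(phi.shape)  # (2, 8, 4, 32)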
@@ -139,9 +129,7 @@ _TRANSFORM_MAP = {
             x - tf.math.reduce_max(x, axis=[1, 2, 3], keepdims=True)),
         h=lambda x: tf.math.exp(
             -0.5 * tf.math.reduce_sum(
-                tf.math.square(x), axis=-1, keepdims=True)),
-        data_normalizer_fn=lambda x: x / (tf.math.sqrt(tf.math.sqrt(
-            tf.cast(tf.shape(x)[-1], tf.float32))))),
+                tf.math.square(x), axis=-1, keepdims=True)),),
     "expmod":
         functools.partial(
             _generalized_kernel,
@@ -149,15 +137,7 @@ _TRANSFORM_MAP = {
             f=lambda x: tf.math.exp(
                 x - tf.math.reduce_max(x, axis=[1, 2, 3], keepdims=True)),
             h=lambda x: tf.math.exp(
-                -0.5 * tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))),
-            data_normalizer_fn=lambda x: x / (tf.math.sqrt(tf.math.sqrt(
-                tf.cast(tf.shape(x)[-1], tf.float32))))),
-    "l2":
-        functools.partial(
-            _generalized_kernel,
-            f=lambda x: x,
-            h=lambda x: tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32)),
-            data_normalizer_fn=lambda x: x),
+                -0.5 * tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))),),
     "identity": lambda x, projection_matrix, is_query: x
 }
 # pylint: enable=g-long-lambda
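Dropping `data_normalizer_fn` from the "exp" and "expmod" entries works because the same normalization now happens upstream: the old lambda divided x by d^(1/4), while the new code multiplies key and query by sqrt(scale) with scale defaulting to 1/sqrt(d), which is the same factor. A quick numeric check of that equivalence (my own sketch, not part of the commit):

import math
import tensorflow as tf

d = 64
x = tf.random.normal([2, 8, 4, d])

old = x / tf.math.sqrt(tf.math.sqrt(tf.cast(d, tf.float32)))  # x / d**0.25
new = x * math.sqrt(1.0 / math.sqrt(d))                       # x * sqrt(scale), scale = 1/sqrt(d)

print(float(tf.reduce_max(tf.abs(old - new))))  # ~0 up to float rounding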
@@ -170,7 +150,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
   Rethinking Attention with Performers
   (https://arxiv.org/abs/2009.14794)
-  - exp (Lemma 1, positive), relu, l2
+  - exp (Lemma 1, positive), relu
   - random/deterministic projection
 
   Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
@@ -195,14 +175,14 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
                redraw=False,
                is_short_seq=False,
                begin_kernel=0,
+               scale=None,
                **kwargs):
     r"""Constructor of KernelAttention.
 
     Args:
       feature_transform: A non-linear transform of the keys and quries.
         Possible transforms are "elu", "relu", "square", "exp", "expmod",
-        "l2", "identity". If <is_short_seq> = True, it is recommended to choose
-        feature_transform as "l2".
+        "identity".
       num_random_features: Number of random features to be used for projection.
         if num_random_features <= 0, no production is used before transform.
       seed: The seed to begin drawing random features. Once the seed is set, the
@@ -216,6 +196,8 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
         (default option).
       begin_kernel: Apply kernel_attention after this sequence id and apply
         softmax attention before this.
+      scale: The value to scale the dot product as described in `Attention Is
+        All You Need`. If None, we use 1/sqrt(dk) as described in the paper.
       **kwargs: The same arguments `MultiHeadAttention` layer.
     """
     if feature_transform not in _TRANSFORM_MAP:
@@ -234,8 +216,11 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     # 1. inference
     # 2. no redraw
     self._seed = seed
     super().__init__(**kwargs)
+    if scale is None:
+      self._scale = 1.0 / math.sqrt(float(self._key_dim))
+    else:
+      self._scale = scale
     self._projection_matrix = None
     if num_random_features > 0:
       self._projection_matrix = create_projection_matrix(
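A hedged usage sketch of the updated constructor, assuming the TF Model Garden (tf-models-official) is installed; the argument names come from the diff above, while the shapes and values are illustrative:

import tensorflow as tf
from official.nlp.modeling.layers import kernel_attention as attention

layer = attention.KernelAttention(
    num_heads=4,
    key_dim=64,
    feature_transform="exp",
    num_random_features=128,
    seed=0,
    redraw=False,
    is_short_seq=False,
    begin_kernel=0,
    scale=None,  # None -> 1/sqrt(key_dim) = 0.125, per the constructor change
)

query = tf.random.normal([2, 16, 256])
value = tf.random.normal([2, 16, 256])
out = layer(query, value)
print(out.shape)  # typically (2, 16, 256) with the default output projection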
@@ -275,6 +260,17 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     Returns:
       attention_output: Multi-headed outputs of attention computation.
     """
+    if is_short_seq:
+      # Note: Applying scalar multiply at the smaller end of einsum improves
+      # XLA performance, but may introduce slight numeric differences in
+      # the Transformer attention head.
+      query = query * self._scale
+      if attention_mask is not None:
+        key = tf.einsum("BSNH,BS->BSNH", key, attention_mask)
+      attention_scores = tf.einsum("BTNH,BSNH->BTSN", query, key)
+      attention_scores = tf.nn.softmax(attention_scores, axis=2)
+      attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
+      return attention_output
     projection_matrix = None
     if self._num_random_features > 0:
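The added branch handles short sequences with ordinary softmax attention, folding the scalar `scale` into the query before the einsum. A standalone illustration of those einsum equations with made-up shapes (not the layer itself):

import tensorflow as tf

B, T, S, N, H = 2, 8, 8, 4, 16
scale = 1.0 / tf.math.sqrt(tf.cast(H, tf.float32))

query = tf.random.normal([B, T, N, H]) * scale
key = tf.random.normal([B, S, N, H])
value = tf.random.normal([B, S, N, H])
attention_mask = tf.ones([B, S])

key = tf.einsum("BSNH,BS->BSNH", key, attention_mask)  # zero out padded positions
scores = tf.einsum("BTNH,BSNH->BTSN", query, key)      # [B, T, S, N]
scores = tf.nn.softmax(scores, axis=2)                 # normalize over source axis S
output = tf.einsum("BTSN,BSNH->BTNH", scores, value)
print(output.shape)  # (2, 8, 4, 16)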
@@ -284,23 +280,23 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     else:
       projection_matrix = self._projection_matrix
-    key = _TRANSFORM_MAP[feature_transform](key, projection_matrix, False)
-    query = _TRANSFORM_MAP[feature_transform](query, projection_matrix, True)
+    # Note: we suspect spliting the scale to key, query yields smaller
+    # approximation variance when random projection is used.
+    # For simplicity, we also split when there's no random projection.
+    key *= math.sqrt(self._scale)
+    query *= math.sqrt(self._scale)
+    key = _TRANSFORM_MAP[feature_transform](key, projection_matrix)
+    query = _TRANSFORM_MAP[feature_transform](query, projection_matrix)
     if attention_mask is not None:
       key = tf.einsum("BSNH,BS->BSNH", key, attention_mask)
-    if is_short_seq:
-      attention_scores = tf.einsum("BTNH,BSNH->BTSN", query, key)
-      attention_scores = tf.nn.softmax(attention_scores, axis=2)
-      attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
-      return attention_output
-    else:
-      kv = tf.einsum("BSNH,BSND->BNDH", key, value)
-      denominator = 1.0 / (
-          tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) +
-          _NUMERIC_STABLER)
-      return tf.einsum("BTNH,BNDH,BTN->BTND", query, kv, denominator)
+    kv = tf.einsum("BSNH,BSND->BNDH", key, value)
+    denominator = 1.0 / (
+        tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) +
+        _NUMERIC_STABLER)
+    return tf.einsum("BTNH,BNDH,BTN->BTND", query, kv, denominator)
 
   def _build_from_signature(self, query, value, key=None):
     super()._build_from_signature(query=query, value=value, key=key)
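The remaining path is the linear (kernel) attention computation: contract key with value once, then normalize per query, so memory stays linear in sequence length rather than quadratic. A standalone sketch using the same einsum equations; the `_NUMERIC_STABLER` value of 1e-6 is my assumption, and the inputs stand in for already-transformed key/query features:

import tensorflow as tf

B, T, S, N, M, D = 2, 8, 8, 4, 32, 16
_NUMERIC_STABLER = 1e-6  # assumed epsilon; the library defines its own constant

query = tf.nn.relu(tf.random.normal([B, T, N, M]))  # transformed query features
key = tf.nn.relu(tf.random.normal([B, S, N, M]))    # transformed key features
value = tf.random.normal([B, S, N, D])

kv = tf.einsum("BSNH,BSND->BNDH", key, value)        # [B, N, D, M]
denominator = 1.0 / (
    tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1))
    + _NUMERIC_STABLER)
output = tf.einsum("BTNH,BNDH,BTN->BTND", query, kv, denominator)
print(output.shape)  # (2, 8, 4, 16)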
@@ -391,6 +387,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
         "redraw": self._redraw,
         "is_short_seq": self._is_short_seq,
         "begin_kernel": self._begin_kernel,
+        "scale": self._scale,
     }
     base_config = super().get_config()
     return dict(list(base_config.items()) + list(config.items()))
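Since `scale` is now part of `get_config`, a serialized layer should restore with the same scaling. A hedged round-trip sketch, assuming tf-models-official and the default Keras `from_config` behaviour; the concrete values are illustrative:

from official.nlp.modeling.layers import kernel_attention as attention

layer = attention.KernelAttention(
    num_heads=2, key_dim=32, feature_transform="exp",
    num_random_features=64, seed=1, scale=0.2)
config = layer.get_config()
print(config["scale"])                # 0.2 (or 1/sqrt(key_dim) when scale=None)

rebuilt = attention.KernelAttention.from_config(config)
print(rebuilt.get_config()["scale"])  # 0.2 again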
official/nlp/modeling/layers/kernel_attention_test.py

@@ -21,7 +21,7 @@ import tensorflow as tf
 from official.nlp.modeling.layers import kernel_attention as attention
 
-_FEATURE_TRANSFORM = ['relu', 'elu', 'exp', 'l2']
+_FEATURE_TRANSFORM = ['relu', 'elu', 'exp']
 _REDRAW = [True, False]
 _TRAINING = [True, False]
 _IS_SHORT_SEQ = [True, False]