ModelZoo / ResNet50_tensorflow · Commit 3eed3e03
Authored Aug 08, 2022 by Krzysztof Choromanski; committed by A. Unique TensorFlower on Aug 08, 2022.
Improving integration of the FAVOR++ mechanism with the test of the Performer's code.
PiperOrigin-RevId: 466056003
Parent: cb7ae42f
Showing 1 changed file with 41 additions and 40 deletions.
official/nlp/modeling/layers/kernel_attention.py @ 3eed3e03
...
...
@@ -49,11 +49,10 @@ def pad_to_chunk_length(tensor, axis, chunk_length, padding=None):
     axis: Axis to pad along.
     chunk_length: The output tensor will have shape[axis] divisible by
       chunk_length.
-    padding: Pad the input tensor across the axis from either left or
-      right if padding is set to "left" or "right"; applies no padding
-      if padding is set to None. In the latter case, the axis
-      dimension of the input tensor must be divisible by the
-      chunk_length.
+    padding: Pad the input tensor across the axis from either left or right if
+      padding is set to "left" or "right"; applies no padding if padding is set
+      to None. In the latter case, the axis dimension of the input tensor must
+      be divisible by the chunk_length.
 
   Returns:
     Padded tensor with shape[axis] divisible by chunk_length.
...
...
@@ -73,10 +72,11 @@ def pad_to_chunk_length(tensor, axis, chunk_length, padding=None):
   else:
     raise ValueError(
         "Illegal padding value; must be one of \"left\", \"right\" or None.")
-  paddings = tf.concat(
-      [tf.zeros([axis, 2], dtype=tf.int32), axis_paddings,
-       tf.zeros([rank - axis - 1, 2], dtype=tf.int32)], axis=0)
+  paddings = tf.concat([
+      tf.zeros([axis, 2], dtype=tf.int32), axis_paddings,
+      tf.zeros([rank - axis - 1, 2], dtype=tf.int32)
+  ], axis=0)
   return tf.pad(tensor, paddings)
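As a concrete illustration of the padding behavior documented above, here is a standalone sketch with made-up sizes (not code from this file): padding on the right along the sequence axis until its length is divisible by chunk_length.

import tensorflow as tf

# Hypothetical sizes: a [batch=2, seq_len=10, dim=4] tensor, chunk_length=4.
tensor = tf.ones([2, 10, 4])
axis, chunk_length = 1, 4
pad_amount = -tensor.shape[axis] % chunk_length   # 2 extra positions needed
paddings = [[0, 0], [0, pad_amount], [0, 0]]      # "right" padding on axis 1
padded = tf.pad(tensor, paddings)
print(padded.shape)  # (2, 12, 4); 12 is divisible by chunk_length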
...
...
@@ -94,7 +94,7 @@ def split_tensor_into_chunks(tensor, axis, chunk_length):
   shape = tf.shape(tensor)
   num_chunks = shape[axis] // chunk_length
   new_shape = tf.concat(
-      [shape[:axis], [num_chunks, chunk_length], shape[(axis + 1):]],
-      axis=0)
+      [shape[:axis], [num_chunks, chunk_length], shape[(axis + 1):]], axis=0)
   return tf.reshape(tensor, new_shape)
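With the dynamic-shape arithmetic above, splitting the sequence axis into chunks is just a reshape. A small sketch with concrete, made-up sizes:

import tensorflow as tf

# Hypothetical sizes: [batch=2, seq_len=12, dim=4] -> [2, num_chunks=3, chunk_length=4, 4].
tensor = tf.reshape(tf.range(2 * 12 * 4), [2, 12, 4])
axis, chunk_length = 1, 4
shape = tf.shape(tensor)
num_chunks = shape[axis] // chunk_length
new_shape = tf.concat(
    [shape[:axis], [num_chunks, chunk_length], shape[(axis + 1):]], axis=0)
print(tf.reshape(tensor, new_shape).shape)  # (2, 3, 4, 4)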
...
...
@@ -128,8 +128,7 @@ def weighted_window_sum(tensor, window_length, window_weights):
   Args:
     tensor: Tensor of shape `[B, T', C', H, dim]`.
     window_length: The length of the window.
-    window_weights: Tensor of shape [window_length] containing window
-      weights.
+    window_weights: Tensor of shape [window_length] containing window weights.
 
   Returns:
     A tensor of shape [B, T', C', H, dim] containing sums over the
...
...
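The window sum computed by weighted_window_sum is simple arithmetic: each position combines the last window_length entries, scaled by window_weights. A toy 1-D sketch follows; the real function operates on [B, T', C', H, dim] chunk tensors, and the weight ordering here is an assumption for illustration only.

import tensorflow as tf

values = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0])
window_length = 3
window_weights = tf.constant([0.25, 0.5, 1.0])  # assumed oldest -> newest
# Frame into overlapping windows of length 3, then take the weighted sum.
framed = tf.signal.frame(values, frame_length=window_length, frame_step=1)
print(tf.reduce_sum(framed * window_weights, axis=-1).numpy())  # [4.25 5.75 7.25]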
@@ -196,14 +195,13 @@ def causal_windowed_performer_attention(query_matrix,
     value_matrix: Value `Tensor` of shape `[B, T, H, out_dim]`.
     chunk_length: Length of each chunk in tokens.
     window_length: Length of attention window in chunks.
-    window_decay: Float window decay factor or `None`. If set,
-      exponentially decay past attention window values by this factor
-      before summation.
-    padding: Pad the query, value and key input tensors across the
-      axis from either left or right if padding is set to "left" or
-      "right"; apply no padding if padding is set to None. In the
-      latter case, the axis dimension of the query, value and key
-      input tensors must be divisible by the chunk_length.
+    window_decay: Float window decay factor or `None`. If set, exponentially
+      decay past attention window values by this factor before summation.
+    padding: Pad the query, value and key input tensors across the axis from
+      either left or right if padding is set to "left" or "right"; apply no
+      padding if padding is set to None. In the latter case, the axis dimension
+      of the query, value and key input tensors must be divisible by the
+      chunk_length.
 
   Returns:
     Window causal performer attention of shape `[B, T, H, out_dim]`.
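To make the window_decay argument concrete: with window_length=4 chunks and a decay factor of 0.5, past chunks enter the windowed sums with exponentially shrinking weights, roughly as in this sketch (illustrative only, not the layer's exact weighting code):

import tensorflow as tf

window_length = 4
window_decay = 0.5
# Weight for the oldest ... newest chunk in the window.
decay_weights = window_decay ** tf.range(window_length - 1, -1, -1, dtype=tf.float32)
print(decay_weights.numpy())  # [0.125 0.25  0.5   1.   ]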
...
...
@@ -302,11 +300,13 @@ def create_projection_matrix(m, d, seed=None):
   return tf.linalg.matmul(tf.linalg.diag(multiplier), final_matrix)
 
 
-def _generalized_kernel(x, projection_matrix, f, h):
+def _generalized_kernel(x, y, is_query, projection_matrix, f, h):
   """Generalized kernel in RETHINKING ATTENTION WITH PERFORMERS.
 
   Args:
     x: The feature being transformed with shape [B, T, N ,H].
+    y: The extra stats-tensor of shape [B, T, N ,H].
+    is_query: True if x is a query-tensor.
     projection_matrix: The matrix with shape [M, H] that we projecct x to, where
       M is the number of projections.
     f: A non-linear function applied on x or projected x.
...
...
@@ -316,7 +316,8 @@ def _generalized_kernel(x, projection_matrix, f, h):
   Returns:
     Transformed feature.
   """
+  del y
+  del is_query
   if projection_matrix is None:
     return h(x) * f(x)
   else:
...
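For orientation, the generalized kernel implemented here maps features x to h(x) * f(x W^T) for a random projection W. Below is a simplified sketch of the projected branch; it mirrors the "exp" entry of _TRANSFORM_MAP in the next hunk and is not the file's exact else-branch, which now also receives the y and is_query arguments used by "expplus".

import tensorflow as tf

x = tf.random.normal([2, 8, 4, 16])             # [B, T, N, H]
projection_matrix = tf.random.normal([32, 16])  # [M, H]
f = tf.math.exp
h = lambda t: tf.math.exp(-0.5 * tf.math.sqrt(tf.cast(tf.shape(t)[-1], tf.float32)))
x_projected = tf.einsum("BTNH,MH->BTNM", x, projection_matrix)
phi_x = h(x) * f(x_projected)                   # transformed feature, [B, T, N, M]
print(phi_x.shape)  # (2, 8, 4, 32)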
@@ -475,6 +476,8 @@ _TRANSFORM_MAP = {
             h=lambda x: tf.math.exp(-0.5 * tf.math.sqrt(
                 tf.cast(tf.shape(x)[-1], tf.float32))),
         ),
+    "expplus":
+        expplus,
     "identity":
         functools.partial(_generalized_kernel, f=lambda x: x, h=lambda x: 1)
 }
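The point of widening the _generalized_kernel signature is that every map entry, including the FAVOR++ "expplus" transform, can now be called the same way. A self-contained sketch of that dispatch pattern, with illustrative names rather than the module's own:

import functools
import tensorflow as tf

def identity_transform(x, y, is_query, projection_matrix, f, h):
  del y, is_query, projection_matrix  # unused by this simple transform
  return h(x) * f(x)

TRANSFORMS = {
    "identity": functools.partial(identity_transform, f=lambda x: x, h=lambda x: 1.0),
}

key = tf.random.normal([2, 8, 4, 16])
query = tf.random.normal([2, 8, 4, 16])
# Every transform takes (x, y, is_query, projection_matrix), so the caller
# never needs a special case for any particular kernel.
key_prime = TRANSFORMS["identity"](key, query, False, None)
print(key_prime.shape)  # (2, 8, 4, 16)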
...
...
@@ -554,18 +557,16 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       causal_chunk_length: Length of each chunk in tokens.
       causal_window_length: Length of attention window in chunks.
       causal_window_decay: Float window decay factor or `None`. If set,
-        exponentially decay past attention window values by this
-        factor before summation.
-      causal_padding: Pad the query, value and key input tensors
-        across the axis from either left or right if padding is set to
-        "left" or "right"; apply no padding if padding is set to None.
-        In the latter case, the axis dimension of the query, value and
-        key input tensors must be divisible by the chunk_length.
-      **kwargs:
-        The same arguments `MultiHeadAttention` layer.
+        exponentially decay past attention window values by this factor before
+        summation.
+      causal_padding: Pad the query, value and key input tensors across the axis
+        from either left or right if padding is set to "left" or "right"; apply
+        no padding if padding is set to None. In the latter case, the axis
+        dimension of the query, value and key input tensors must be divisible by
+        the chunk_length.
+      **kwargs: The same arguments `MultiHeadAttention` layer.
     """
-    if (feature_transform not in _TRANSFORM_MAP and
-        feature_transform != "expplus"):
+    if feature_transform not in _TRANSFORM_MAP:
       raise ValueError("Unsupported feature_transform. The supported "
                        "feature_transform are %s. "
                        "Got '%s'." % (_TRANSFORM_MAP.keys(), feature_transform))
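With "expplus" registered in _TRANSFORM_MAP, selecting FAVOR++ becomes an ordinary constructor argument. A hypothetical construction follows; num_heads, key_dim, and num_random_features are assumptions for illustration, and only feature_transform="expplus" comes directly from this diff.

import tensorflow as tf
from official.nlp.modeling.layers import kernel_attention

layer = kernel_attention.KernelAttention(
    num_heads=8,
    key_dim=64,
    feature_transform="expplus",  # FAVOR++ kernel feature transform
    num_random_features=256)      # assumed projection size

query = tf.random.normal([2, 16, 512])
value = tf.random.normal([2, 16, 512])
print(layer(query, value).shape)  # expected (2, 16, 512)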
...
...
@@ -661,12 +662,10 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     key *= tf.math.sqrt(scale)
     query *= tf.math.sqrt(scale)
-    if feature_transform != "expplus":
-      key_prime = _TRANSFORM_MAP[feature_transform](key, projection_matrix)
-      query_prime = _TRANSFORM_MAP[feature_transform](query, projection_matrix)
-    else:
-      key_prime = expplus(key, query, False, projection_matrix)
-      query_prime = expplus(query, key, True, projection_matrix)
+    key_prime = _TRANSFORM_MAP[feature_transform](key, query, False,
+                                                  projection_matrix)
+    query_prime = _TRANSFORM_MAP[feature_transform](query, key, True,
+                                                    projection_matrix)
     if attention_mask is not None:
       key_prime = tf.einsum("BSNH,BS->BSNH", key_prime, attention_mask)
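Downstream, the transformed query_prime and key_prime stand in for the softmax query/key features; the non-causal branch in the following hunk combines them with einsum patterns like this schematic (the normalization shown is the usual kernel-attention row normalization and may differ in detail from the layer's code):

import tensorflow as tf

B, T, S, N, H, out_dim = 2, 6, 6, 4, 32, 16
query_prime = tf.random.uniform([B, T, N, H])    # non-negative kernel features
key_prime = tf.random.uniform([B, S, N, H])
value = tf.random.normal([B, S, N, out_dim])

scores = tf.einsum("BTNH,BSNH->BTSN", query_prime, key_prime)
scores /= tf.reduce_sum(scores, axis=2, keepdims=True)          # normalize over S
attention_output = tf.einsum("BTSN,BSNH->BTNH", scores, value)  # pattern as in this file
print(attention_output.shape)  # (2, 6, 4, 16)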
...
...
@@ -677,7 +676,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
     elif self.use_causal_windowed:
       attention_output = causal_windowed_performer_attention(
-          query_prime, key_prime, value,
+          query_prime,
+          key_prime,
+          value,
           chunk_length=self.causal_chunk_length,
           window_length=self.causal_window_length,
           window_decay=self.causal_window_decay,
...
...