ModelZoo / ResNet50_tensorflow · Commit 3eed3e03

Authored Aug 08, 2022 by Krzysztof Choromanski; committed by A. Unique TensorFlower on Aug 08, 2022.

Improving integration of the FAVOR++ mechanism with the test of the Performer's code.

PiperOrigin-RevId: 466056003
Parent: cb7ae42f
Showing 1 changed file with 41 additions and 40 deletions.

official/nlp/modeling/layers/kernel_attention.py (+41, -40)
@@ -49,11 +49,10 @@ def pad_to_chunk_length(tensor, axis, chunk_length, padding=None):
     axis: Axis to pad along.
     chunk_length: The output tensor will have shape[axis] divisible by
       chunk_length.
-    padding: Pad the input tensor across the axis from either left or
-      right if padding is set to "left" or "right"; applies no padding
-      if padding is set to None. In the latter case, the axis
-      dimension of the input tensor must be divisible by the
-      chunk_length.
+    padding: Pad the input tensor across the axis from either left or right if
+      padding is set to "left" or "right"; applies no padding if padding is set
+      to None. In the latter case, the axis dimension of the input tensor must
+      be divisible by the chunk_length.
 
   Returns:
     Padded tensor with shape[axis] divisible by chunk_length.
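The rewrapped docstring above keeps the same contract: after padding, shape[axis] must be divisible by chunk_length, with the extra positions added on the left, on the right, or not at all. A minimal sketch of that contract, assuming static shapes; pad_axis_to_multiple is an illustrative name, not the file's pad_to_chunk_length:

import tensorflow as tf

def pad_axis_to_multiple(tensor, axis, chunk_length, padding="right"):
  # How many positions are missing to reach the next multiple of chunk_length.
  pad = -int(tensor.shape[axis]) % chunk_length
  paddings = [[0, 0]] * tensor.shape.rank
  paddings[axis] = [pad, 0] if padding == "left" else [0, pad]
  return tf.pad(tensor, paddings)

x = tf.ones([2, 7, 4])                               # sequence axis has length 7
y = pad_axis_to_multiple(x, axis=1, chunk_length=3)  # pad on the right
print(y.shape)                                       # (2, 9, 4), divisible by 3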
@@ -73,10 +72,11 @@ def pad_to_chunk_length(tensor, axis, chunk_length, padding=None):
   else:
     raise ValueError(
         "Illegal padding value; must be one of \"left\", \"right\" or None.")
-  paddings = tf.concat(
-      [tf.zeros([axis, 2], dtype=tf.int32),
-       axis_paddings,
-       tf.zeros([rank - axis - 1, 2], dtype=tf.int32)],
-      axis=0)
+  paddings = tf.concat([
+      tf.zeros([axis, 2], dtype=tf.int32),
+      axis_paddings,
+      tf.zeros([rank - axis - 1, 2], dtype=tf.int32)
+  ],
+                       axis=0)
   return tf.pad(tensor, paddings)
 
@@ -94,7 +94,7 @@ def split_tensor_into_chunks(tensor, axis, chunk_length):
   shape = tf.shape(tensor)
   num_chunks = shape[axis] // chunk_length
-  new_shape = tf.concat(
-      [shape[:axis], [num_chunks, chunk_length], shape[(axis + 1):]], axis=0)
+  new_shape = tf.concat(
+      [shape[:axis], [num_chunks, chunk_length], shape[(axis + 1):]], axis=0)
   return tf.reshape(tensor, new_shape)
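The chunk-splitting shown in this hunk is unchanged in substance; the sketch below replays the same tf.concat/tf.reshape recipe on toy values to make the shape bookkeeping concrete (the sequence axis must already be divisible by chunk_length, which is what the padding helper above guarantees):

import tensorflow as tf

x = tf.reshape(tf.range(2 * 6 * 4), [2, 6, 4])   # [B=2, T=6, dim=4]
axis, chunk_length = 1, 3
shape = tf.shape(x)
num_chunks = shape[axis] // chunk_length
new_shape = tf.concat(
    [shape[:axis], [num_chunks, chunk_length], shape[(axis + 1):]], axis=0)
chunked = tf.reshape(x, new_shape)
print(chunked.shape)                             # (2, 2, 3, 4)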
@@ -128,8 +128,7 @@ def weighted_window_sum(tensor, window_length, window_weights):
   Args:
     tensor: Tensor of shape `[B, T', C', H, dim]`.
     window_length: The length of the window.
-    window_weights: Tensor of shape [window_length] containing window
-      weights.
+    window_weights: Tensor of shape [window_length] containing window weights.
 
   Returns:
     A tensor of shape [B, T', C', H, dim] containing sums over the
@@ -196,14 +195,13 @@ def causal_windowed_performer_attention(query_matrix,
     value_matrix: Value `Tensor` of shape `[B, T, H, out_dim]`.
     chunk_length: Length of each chunk in tokens.
     window_length: Length of attention window in chunks.
-    window_decay: Float window decay factor or `None`. If set,
-      exponentially decay past attention window values by this factor
-      before summation.
-    padding: Pad the query, value and key input tensors across the
-      axis from either left or right if padding is set to "left" or
-      "right"; apply no padding if padding is set to None. In the
-      latter case, the axis dimension of the query, value and key
-      input tensors must be divisible by the chunk_length.
+    window_decay: Float window decay factor or `None`. If set, exponentially
+      decay past attention window values by this factor before summation.
+    padding: Pad the query, value and key input tensors across the axis from
+      either left or right if padding is set to "left" or "right"; apply no
+      padding if padding is set to None. In the latter case, the axis dimension
+      of the query, value and key input tensors must be divisible by the
+      chunk_length.
 
   Returns:
     Window causal performer attention of shape `[B, T, H, out_dim]`.
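For the window_decay argument described above, here is a hedged illustration of what "exponentially decay past attention window values by this factor before summation" can look like as a weight vector; the exact ordering and normalization inside causal_windowed_performer_attention may differ:

import tensorflow as tf

window_length = 4
window_decay = 0.9
# Weight 1.0 for the newest chunk, decay**k for the chunk k steps further back.
window_weights = window_decay ** tf.range(window_length - 1, -1, -1, dtype=tf.float32)
print(window_weights.numpy())   # approximately [0.729, 0.81, 0.9, 1.0]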
@@ -302,11 +300,13 @@ def create_projection_matrix(m, d, seed=None):
   return tf.linalg.matmul(tf.linalg.diag(multiplier), final_matrix)
 
 
-def _generalized_kernel(x, projection_matrix, f, h):
+def _generalized_kernel(x, y, is_query, projection_matrix, f, h):
   """Generalized kernel in RETHINKING ATTENTION WITH PERFORMERS.
 
   Args:
     x: The feature being transformed with shape [B, T, N, H].
+    y: The extra stats-tensor of shape [B, T, N, H].
+    is_query: True if x is a query-tensor.
     projection_matrix: The matrix with shape [M, H] that we project x to, where
       M is the number of projections.
     f: A non-linear function applied on x or projected x.
@@ -316,7 +316,8 @@ def _generalized_kernel(x, projection_matrix, f, h):
   Returns:
     Transformed feature.
   """
+  del y
+  del is_query
   if projection_matrix is None:
     return h(x) * f(x)
   else:
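The `del y` / `del is_query` lines make the interface change concrete: every feature transform now accepts the extra FAVOR++ arguments, and transforms that do not need them simply discard them. A hedged sketch of a kernel written against that unified signature; relu_kernel is an illustrative name, not the file's code:

import tensorflow as tf

def relu_kernel(x, y, is_query, projection_matrix):
  del y, is_query                      # unused by kernels that predate FAVOR++
  if projection_matrix is None:
    return tf.nn.relu(x)
  return tf.nn.relu(tf.einsum("BTNH,MH->BTNM", x, projection_matrix))

x = tf.random.normal([1, 5, 2, 8])     # [B, T, N, H]
proj = tf.random.normal([16, 8])       # [M, H]
print(relu_kernel(x, x, True, proj).shape)   # (1, 5, 2, 16)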
@@ -475,6 +476,8 @@ _TRANSFORM_MAP = {
         h=lambda x: tf.math.exp(-0.5 * tf.math.sqrt(
             tf.cast(tf.shape(x)[-1], tf.float32))),
     ),
+    "expplus":
+        expplus,
     "identity":
         functools.partial(_generalized_kernel, f=lambda x: x, h=lambda x: 1)
 }
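Registering "expplus" in _TRANSFORM_MAP is what lets the rest of the file stop special-casing FAVOR++. The toy map below mimics that pattern under simplified assumptions; the real entries wrap _generalized_kernel with the f/h pairs from the Performer paper, and the real expplus genuinely uses the y/is_query arguments:

import functools
import tensorflow as tf

def generic_kernel(x, y, is_query, projection_matrix, f):
  del y, is_query                          # ignored by the simple kernels
  return f(x if projection_matrix is None
           else tf.einsum("BTNH,MH->BTNM", x, projection_matrix))

def expplus_like(x, y, is_query, projection_matrix):
  # Stand-in for FAVOR++'s expplus, which really does use y and is_query.
  return tf.exp(generic_kernel(x, None, False, projection_matrix, tf.identity))

TRANSFORMS = {
    "relu": functools.partial(generic_kernel, f=tf.nn.relu),
    "expplus": expplus_like,
}

key = tf.random.normal([1, 4, 2, 8])       # [B, S, N, H]
query = tf.random.normal([1, 4, 2, 8])
proj = tf.random.normal([16, 8])           # [M, H]
# One uniform call site, no special-casing of "expplus":
key_prime = TRANSFORMS["expplus"](key, query, False, proj)
query_prime = TRANSFORMS["expplus"](query, key, True, proj)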
@@ -554,18 +557,16 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       causal_chunk_length: Length of each chunk in tokens.
       causal_window_length: Length of attention window in chunks.
       causal_window_decay: Float window decay factor or `None`. If set,
-        exponentially decay past attention window values by this
-        factor before summation.
-      causal_padding: Pad the query, value and key input tensors
-        across the axis from either left or right if padding is set to
-        "left" or "right"; apply no padding if padding is set to None.
-        In the latter case, the axis dimension of the query, value and
-        key input tensors must be divisible by the chunk_length.
-      **kwargs:
-        The same arguments `MultiHeadAttention` layer.
+        exponentially decay past attention window values by this factor before
+        summation.
+      causal_padding: Pad the query, value and key input tensors across the axis
+        from either left or right if padding is set to "left" or "right"; apply
+        no padding if padding is set to None. In the latter case, the axis
+        dimension of the query, value and key input tensors must be divisible by
+        the chunk_length.
+      **kwargs: The same arguments `MultiHeadAttention` layer.
     """
-    if (feature_transform not in _TRANSFORM_MAP and
-        feature_transform != "expplus"):
+    if feature_transform not in _TRANSFORM_MAP:
       raise ValueError("Unsupported feature_transform. The supported "
                        "feature_transform are %s. "
                        "Got '%s'." % (_TRANSFORM_MAP.keys(), feature_transform))
@@ -661,12 +662,10 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       key *= tf.math.sqrt(scale)
       query *= tf.math.sqrt(scale)
 
-    if feature_transform != "expplus":
-      key_prime = _TRANSFORM_MAP[feature_transform](key, projection_matrix)
-      query_prime = _TRANSFORM_MAP[feature_transform](query, projection_matrix)
-    else:
-      key_prime = expplus(key, query, False, projection_matrix)
-      query_prime = expplus(query, key, True, projection_matrix)
+    key_prime = _TRANSFORM_MAP[feature_transform](key, query, False,
+                                                  projection_matrix)
+    query_prime = _TRANSFORM_MAP[feature_transform](query, key, True,
+                                                    projection_matrix)
 
     if attention_mask is not None:
       key_prime = tf.einsum("BSNH,BS->BSNH", key_prime, attention_mask)
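The retained masking step multiplies the kernelized key features by a [B, S] attention mask so padded positions contribute nothing to the prefix sums. A small self-contained example of that einsum:

import tensorflow as tf

key_prime = tf.ones([2, 3, 1, 4])                       # [B, S, N, H]
attention_mask = tf.constant([[1., 1., 0.],
                              [1., 0., 0.]])            # [B, S]
masked = tf.einsum("BSNH,BS->BSNH", key_prime, attention_mask)
print(masked[0, :, 0, 0].numpy())                       # [1. 1. 0.]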
@@ -677,7 +676,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
     elif self.use_causal_windowed:
       attention_output = causal_windowed_performer_attention(
-          query_prime, key_prime, value,
+          query_prime,
+          key_prime,
+          value,
           chunk_length=self.causal_chunk_length,
           window_length=self.causal_window_length,
           window_decay=self.causal_window_decay,
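A hedged end-to-end usage sketch tying the pieces together. The feature_transform, use_causal_windowed and causal_* arguments follow the docstring shown in this diff; num_random_features is assumed from the surrounding file, and the call signature is assumed to follow tf.keras MultiHeadAttention, so details may differ:

import tensorflow as tf
from official.nlp.modeling.layers import kernel_attention

layer = kernel_attention.KernelAttention(
    num_heads=4,
    key_dim=64,
    feature_transform="expplus",      # FAVOR++, now dispatched via _TRANSFORM_MAP
    num_random_features=128,          # assumed parameter name
    use_causal_windowed=True,
    causal_chunk_length=16,
    causal_window_length=4,
    causal_window_decay=0.97,
    causal_padding="right")

x = tf.random.normal([2, 128, 64])    # [B, T, dim]
out = layer(query=x, value=x)         # causal windowed Performer attention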