ModelZoo / ResNet50_tensorflow · Commit d1fca260

Commit d1fca260, authored Jul 26, 2022 by Avi Dubey, committed by A. Unique TensorFlower on Jul 26, 2022

windowed causal performer

PiperOrigin-RevId: 463429471
Parent: 1db7588c
Showing 2 changed files with 194 additions and 1 deletion (+194 -1)
official/nlp/modeling/layers/kernel_attention.py        +161 -1
official/nlp/modeling/layers/kernel_attention_test.py    +33 -0
official/nlp/modeling/layers/kernel_attention.py (view file @ d1fca260)
...
...
@@ -41,6 +41,148 @@ class KernelMask(tf.keras.layers.Layer):
    return mask


def pad_to_chunk_length(tensor, axis, chunk_length, pad="right"):
  """Pads a tensor so that shape[axis] is divisible by chunk_length.

  Args:
    tensor: Input tensor to pad.
    axis: Axis to pad along.
    chunk_length: The output tensor will have shape[axis] divisible by
      chunk_length.
    pad: Pad the input tensor across the axis from left if pad="left", right if
      pad="right", or apply no padding if pad=None. In the latter case, the
      axis dimension of the input tensor must be divisible by the chunk_length.

  Returns:
    Padded tensor with shape[axis] divisible by chunk_length.
  """
  shape = tf.shape(tensor)
  rank = tf.rank(tensor)
  if axis < 0:
    axis += rank
  axis_length = shape[axis]
  pad_length = -axis_length % chunk_length
  if pad == "right":
    pad_width_2 = [[0, pad_length]]
  elif pad == "left":
    pad_width_2 = [[pad_length, 0]]
  else:
    if pad_length != 0:
      raise ValueError("When padding is not set, the axis dimension "
                       "has to be divisible by the chunk_length.")
    return tensor
  pad_width = tf.concat(
      [tf.zeros([axis, 2], dtype=tf.int32),
       pad_width_2,
       tf.zeros([rank - axis - 1, 2], dtype=tf.int32)],
      axis=0)
  return tf.pad(tensor, pad_width)
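As a quick, hedged illustration of the helper above (the input shape and axis here are example assumptions, not values from the commit, and `import tensorflow as tf` plus eager execution are assumed):

  # Hypothetical example of pad_to_chunk_length; not part of the commit.
  x = tf.ones([2, 7, 4])                                  # sequence axis (axis=-2) has length 7
  y = pad_to_chunk_length(x, axis=-2, chunk_length=3, pad="right")
  print(y.shape)                                          # (2, 9, 4): two zero rows appended at the end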
def split_tensor_into_chunks(tensor, axis, chunk_length):
  """Reshape tensor along given axis using chunk_length.

  Args:
    tensor: Input tensor.
    axis: Reshape tensor along this axis.
    chunk_length: Split the axis into [axis/chunk_length, chunk_length].

  Returns:
    Reshaped tensor.
  """
  shape = tf.shape(tensor)
  num_chunks = shape[axis] // chunk_length
  new_shape = tf.concat(
      [shape[:axis], [num_chunks, chunk_length], shape[(axis + 1):]], axis=0)
  return tf.reshape(tensor, new_shape)
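Continuing the same hedged sketch for the reshape helper (again with assumed example shapes, not values from the commit):

  # Hypothetical example of split_tensor_into_chunks; not part of the commit.
  x = tf.ones([2, 9, 4, 8])                               # [batch, T, heads, dim] with T = 9
  chunks = split_tensor_into_chunks(x, axis=-3, chunk_length=3)
  print(chunks.shape)                                     # (2, 3, 3, 4, 8): T split into 3 chunks of length 3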
def windowed_causal_performer_attention(query_matrix,
                                        key_matrix,
                                        value_matrix,
                                        chunk_length,
                                        window_length,
                                        pad="right"):
  """Applies windowed causal kernel attention with query, key, value tensors.

  We partition the T-length input sequence into N chunks, each of chunk_length
  tokens (thus: T = N * chunk_length). Within each chunk, we apply
  bidirectional (non-causal) Performers' implicit attention, and we model
  relationships between different chunks using Performers' causal attention. We
  consider a windowed causal variant of Performer, where the current chunk
  attends only to the window of window_length of the most recent chunks.

  Below is an example with T=9, chunk_length=3, window_length=1. 1 indicates
  attention is computed between the pair while 0 indicates attention is not
  computed between the pair:

  111000000
  111000000
  111000000
  111111000
  111111000
  111111000
  000111111
  000111111
  000111111

  The user can either ensure that sequence_length is divisible by chunk_length,
  or use pad="left"/"right" to pad the sequence at the start or end,
  respectively, so that it becomes divisible by chunk_length.

  Args:
    query_matrix: Kernel query `Tensor` of shape `[B, T, N, dim]`.
    key_matrix: Kernel key `Tensor` of shape `[B, T, N, dim]`.
    value_matrix: Value `Tensor` of shape `[B, T, N, out_dim]`.
    chunk_length: Length of each chunk in tokens.
    window_length: Length of attention window in chunks.
    pad: Pad the query, value and key input tensors across the T dimension from
      left if pad="left", right if pad="right", or apply no padding if
      pad=None. In the latter case, the T dimension of the input tensors must
      be divisible by the chunk_length.

  Returns:
    Windowed causal performer attention of shape `[B, T, N, out_dim]`.
  """
  old_shape = tf.shape(value_matrix)

  query_matrix = pad_to_chunk_length(query_matrix, -3, chunk_length, pad)
  key_matrix = pad_to_chunk_length(key_matrix, -3, chunk_length, pad)
  value_matrix = pad_to_chunk_length(value_matrix, -3, chunk_length, pad)
  new_shape = tf.shape(value_matrix)

  chunked_query_matrix = split_tensor_into_chunks(
      query_matrix, -3, chunk_length)  # [-1, T//chunk_length, chunk_length, N, dim]
  chunked_key_matrix = split_tensor_into_chunks(
      key_matrix, -3, chunk_length)  # [-1, T//chunk_length, chunk_length, N, dim]
  chunked_value_matrix = split_tensor_into_chunks(
      value_matrix, -3, chunk_length)  # [-1, T//chunk_length, chunk_length, N, out_dim]

  kp_v = tf.einsum("BNCHD,BNCHO->BNHDO", chunked_key_matrix,
                   chunked_value_matrix)
  kp_v_cumsum = tf.cumsum(kp_v, axis=-4)
  kp_v_winsum = kp_v_cumsum - tf.pad(
      kp_v_cumsum,
      [[0, 0], [window_length, 0], [0, 0], [0, 0], [0, 0]])[:, :-window_length]
  numerator = tf.einsum("BNCHD,BNHDO->BNCHO", chunked_query_matrix,
                        kp_v_winsum)

  k_sum = tf.reduce_sum(chunked_key_matrix, axis=-3)
  k_cumsum = tf.cumsum(k_sum, axis=-3)
  k_winsum = k_cumsum - tf.pad(
      k_cumsum,
      [[0, 0], [window_length, 0], [0, 0], [0, 0]])[:, :-window_length]
  denominator = tf.einsum("BNCHD,BNHD->BNCH", chunked_query_matrix, k_winsum)
  denominator = tf.expand_dims(denominator, -1) + _NUMERIC_STABLER

  attention = numerator / denominator
  attention = tf.reshape(attention, new_shape)
  start = tf.zeros([len(old_shape)], dtype=old_shape.dtype)
  attention = tf.slice(attention, start, old_shape)
  return attention
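The cumsum/pad pattern above implements a sliding sum over chunks: subtracting the cumulative sum shifted by window_length leaves only the contribution of the most recent window_length chunks. A minimal one-dimensional sketch of the same trick, with illustrative values that are not from the commit (assumes `import tensorflow as tf`):

  # Hypothetical 1-D illustration of the cumsum-based windowed sum.
  window_length = 2
  per_chunk = tf.constant([1., 2., 3., 4., 5.])           # one value per chunk
  cumsum = tf.cumsum(per_chunk, axis=0)                   # [1, 3, 6, 10, 15]
  shifted = tf.pad(cumsum, [[window_length, 0]])[:-window_length]  # [0, 0, 1, 3, 6]
  window_sum = cumsum - shifted                           # [1, 3, 5, 7, 9] = sum over the last 2 chunks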
def create_projection_matrix(m, d, seed=None):
  r"""Constructs the matrix of random projections.
...
...
@@ -304,6 +446,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
               begin_kernel=0,
               scale=None,
               scale_by_length=False,
               use_windowed_causal=False,
               chunk_length=1,
               window_length=3,
               **kwargs):
    r"""Constructor of KernelAttention.
...
...
@@ -330,9 +475,14 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
        the dot product based on key length. Set as log_512^(n) to stabilize
        attention entropy against length. Refer to
        https://kexue.fm/archives/8823 for details.
      use_windowed_causal: If true, perform windowed causal attention. See the
        windowed_causal_performer_attention function docstring for more
        details.
      chunk_length: Length of each chunk in tokens.
      window_length: Length of attention window in chunks.
      **kwargs: The same arguments as the `MultiHeadAttention` layer.
    """
-    if feature_transform not in _TRANSFORM_MAP and feature_transform != "expplus":
+    if (feature_transform not in _TRANSFORM_MAP and
+        feature_transform != "expplus"):
       raise ValueError(
           "Unsupported feature_transform. The supported "
           "feature_transform are %s. "
           "Got '%s'." % (_TRANSFORM_MAP.keys(), feature_transform))
...
...
@@ -359,6 +509,12 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
    self._projection_matrix = create_projection_matrix(
        self._num_random_features, self._key_dim,
        tf.constant([self._seed, self._seed + 1]))
    self.use_windowed_causal = use_windowed_causal
    self.chunk_length = chunk_length
    self.window_length = window_length
    if self.use_windowed_causal and self._is_short_seq:
      raise ValueError(
          "use_windowed_causal and short_seq methods are mutually exclusive")
  def _compute_attention(self,
                         query,
...
...
@@ -394,6 +550,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
      attention_output: Multi-headed outputs of attention computation.
    """
    projection_matrix = None
    if self._num_random_features > 0:
      if self._redraw and training:
        projection_matrix = create_projection_matrix(
            self._num_random_features,
...
...
@@ -433,6 +590,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
      attention_scores = tf.einsum("BTNH,BSNH->BTSN", query_prime, key_prime)
      attention_scores = tf.nn.softmax(attention_scores, axis=2)
      attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
    elif self.use_windowed_causal:
      attention_output = windowed_causal_performer_attention(
          query_prime, key_prime, value, self.chunk_length, self.window_length)
    else:
      kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
      denominator = 1.0 / (
...
...
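To make the new code path concrete, here is a condensed, hedged usage sketch that mirrors the new test in the file below (the import alias, the feature_transform value, and the tensor shapes are assumptions chosen for illustration, not prescribed by the commit):

  # Hypothetical usage sketch; see test_windowed_causal_attention_projection below.
  import tensorflow as tf
  from official.nlp.modeling.layers import kernel_attention as attention

  layer = attention.KernelAttention(
      num_heads=12,
      key_dim=64,
      feature_transform="exp",     # assumed to be one of the supported transforms
      num_random_features=127,
      is_short_seq=False,          # True would raise: mutually exclusive with windowed causal
      use_windowed_causal=True,
      chunk_length=8,
      window_length=3)

  query = tf.random.normal(shape=(2, 1024, 64))
  mask = tf.cast(tf.zeros((2, 1024), dtype=tf.int32), dtype=tf.float32)
  output = layer(query=query, value=query, attention_mask=mask, training=False)
  print(output.shape)              # expected: (2, 1024, 64)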
official/nlp/modeling/layers/kernel_attention_test.py (view file @ d1fca260)
...
...
@@ -60,6 +60,39 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
        training=training)
    self.assertEqual(output.shape, [batch_size, seq_length, key_dim])
  @parameterized.parameters(
      itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
                        [0]))
  def test_windowed_causal_attention_projection(self, feature_transform,
                                                num_random_features, training,
                                                redraw, begin_kernel):
    num_heads = 12
    key_dim = 64
    seq_length = 1024
    batch_size = 2
    test_layer = attention.KernelAttention(
        num_heads=num_heads,
        key_dim=key_dim,
        feature_transform=feature_transform,
        num_random_features=num_random_features,
        redraw=redraw,
        is_short_seq=False,
        begin_kernel=begin_kernel,
        use_windowed_causal=True,
        chunk_length=8,
        window_length=3)
    query = tf.random.normal(shape=(batch_size, seq_length, key_dim))
    value = query
    encoder_inputs_mask = tf.zeros((batch_size, seq_length), dtype=tf.int32)
    masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
    output = test_layer(
        query=query, value=value, attention_mask=masks, training=training)
    self.assertEqual(output.shape, [batch_size, seq_length, key_dim])

  @parameterized.parameters(
      itertools.product(_FEATURE_TRANSFORM, [0], _TRAINING, [False],
                        _IS_SHORT_SEQ, _BEGIN_KERNEL))
...
...