Project: ModelZoo / ResNet50_tensorflow

Commit 10673875
Authored Sep 27, 2022 by Frederick Liu; committed by A. Unique TensorFlower, Sep 27, 2022

[kernel] Add streaming support.

PiperOrigin-RevId: 477214841

Parent: 798d318f

Changes: 2 changed files with 158 additions and 53 deletions (+158 −53)
official/nlp/modeling/layers/kernel_attention.py (+100 −52)
official/nlp/modeling/layers/kernel_attention_test.py (+58 −1)
official/nlp/modeling/layers/kernel_attention.py @ 10673875
@@ -160,7 +160,8 @@ def causal_windowed_performer_attention(query_matrix,
                                         chunk_length,
                                         window_length,
                                         window_decay=None,
-                                        padding=None):
+                                        padding=None,
+                                        cache=None):
   """Applies windowed causal kernel attention with query, key, value tensors.

   We partition the T-length input sequence into N chunks, each of
@@ -202,55 +203,70 @@ def causal_windowed_performer_attention(query_matrix,
       padding if padding is set to None. In the latter case, the axis dimension
       of the query, value and key input tensors must be divisible by the
       chunk_length.
+    cache: Cache to accumulate history in memory. Used at inference time
+      (streaming, decoding) for causal attention.

   Returns:
     Window causal performer attention of shape `[B, T, H, out_dim]`.
   """
-  old_shape = tf.shape(value_matrix)
-  query_matrix = pad_to_chunk_length(query_matrix, -3, chunk_length, padding)
-  key_matrix = pad_to_chunk_length(key_matrix, -3, chunk_length, padding)
-  value_matrix = pad_to_chunk_length(value_matrix, -3, chunk_length, padding)
-  new_shape = tf.shape(value_matrix)
-  chunked_query_matrix = split_tensor_into_chunks(
-      query_matrix, -3,
-      chunk_length)  # [-1, T//chunk_length, chunk_length, N, dim]
-  chunked_key_matrix = split_tensor_into_chunks(
-      key_matrix, -3,
-      chunk_length)  # [-1, T//chunk_length, chunk_length, N, dim]
-  chunked_value_matrix = split_tensor_into_chunks(
-      value_matrix, -3,
-      chunk_length)  # [-1, T//chunk_length, chunk_length, N, out_dim]
-  kp_v = tf.einsum("BTCHD,BTCHO->BTHDO", chunked_key_matrix,
-                   chunked_value_matrix)
-  k_sum = tf.math.reduce_sum(chunked_key_matrix, axis=-3, keepdims=True)
-  if window_decay is None:
-    kp_v_winsum = rectangular_window_sum(kp_v, window_length)
-    k_winsum = rectangular_window_sum(k_sum, window_length)
-  else:
-    # Compute exponentially decaying weights.
-    decaying_weights = tf.math.pow(
-        tf.convert_to_tensor(window_decay, dtype=value_matrix.dtype),
-        tf.range(window_length - 1, -1, delta=-1, dtype=value_matrix.dtype))
-    kp_v_winsum = weighted_window_sum(kp_v, window_length, decaying_weights)
-    k_winsum = weighted_window_sum(k_sum, window_length, decaying_weights)
-  numerator = tf.einsum("BTCHD,BTHDO->BTCHO", chunked_query_matrix, kp_v_winsum)
-  k_winsum = tf.squeeze(k_winsum, -3)
-  denominator = tf.einsum("BTCHD,BTHD->BTCH", chunked_query_matrix, k_winsum)
-  denominator = tf.expand_dims(denominator, -1) + _NUMERIC_STABLER
-  attention = numerator / denominator
-  attention = tf.reshape(attention, new_shape)
-  start = tf.zeros([len(old_shape)], dtype=old_shape.dtype)
-  attention = tf.slice(attention, start, old_shape)
+  if cache is None:
+    # Training
+    old_shape = tf.shape(value_matrix)
+    query_matrix = pad_to_chunk_length(query_matrix, -3, chunk_length, padding)
+    key_matrix = pad_to_chunk_length(key_matrix, -3, chunk_length, padding)
+    value_matrix = pad_to_chunk_length(value_matrix, -3, chunk_length, padding)
+    new_shape = tf.shape(value_matrix)
+    chunked_query_matrix = split_tensor_into_chunks(
+        query_matrix, -3,
+        chunk_length)  # [-1, T//chunk_length, chunk_length, N, dim]
+    chunked_key_matrix = split_tensor_into_chunks(
+        key_matrix, -3,
+        chunk_length)  # [-1, T//chunk_length, chunk_length, N, dim]
+    chunked_value_matrix = split_tensor_into_chunks(
+        value_matrix, -3,
+        chunk_length)  # [-1, T//chunk_length, chunk_length, N, out_dim]
+    kp_v = tf.einsum("BTCHD,BTCHO->BTHDO", chunked_key_matrix,
+                     chunked_value_matrix)
+    k_sum = tf.math.reduce_sum(chunked_key_matrix, axis=-3, keepdims=True)
+    if window_decay is None:
+      kp_v_winsum = rectangular_window_sum(kp_v, window_length)
+      k_winsum = rectangular_window_sum(k_sum, window_length)
+    else:
+      # Compute exponentially decaying weights.
+      decaying_weights = tf.math.pow(
+          tf.convert_to_tensor(window_decay, dtype=value_matrix.dtype),
+          tf.range(window_length - 1, -1, delta=-1, dtype=value_matrix.dtype))
+      kp_v_winsum = weighted_window_sum(kp_v, window_length, decaying_weights)
+      k_winsum = weighted_window_sum(k_sum, window_length, decaying_weights)
+    numerator = tf.einsum("BTCHD,BTHDO->BTCHO", chunked_query_matrix,
+                          kp_v_winsum)
+    k_winsum = tf.squeeze(k_winsum, -3)
+    denominator = tf.einsum("BTCHD,BTHD->BTCH", chunked_query_matrix, k_winsum)
+    denominator = tf.expand_dims(denominator, -1) + _NUMERIC_STABLER
+    attention = numerator / denominator
+    attention = tf.reshape(attention, new_shape)
+    start = tf.zeros([len(old_shape)], dtype=old_shape.dtype)
+    attention = tf.slice(attention, start, old_shape)
+    # Queued window cache (drop instead of decay) not yet supported.
+  else:
+    # Streaming
+    if window_decay is None or window_decay > 1.0 or window_decay < 0.0:
+      raise ValueError("window_decay should be in (0.0, 1.0) and not None.")
+    kv = cache["kv"] + tf.einsum("BTHD,BTHO->BHOD", key_matrix, value_matrix)
+    cache["kv"] = kv * window_decay
+    k_sum = cache["k_sum"] + tf.reduce_sum(key_matrix, axis=1)
+    cache["k_sum"] = k_sum * window_decay
+    denominator = tf.einsum("BTHD,BHD->BTH", query_matrix, k_sum)
+    attention = tf.einsum("BTHD,BHOD,BTH->BTHO", query_matrix, kv,
+                          1.0 / (denominator + _NUMERIC_STABLER))
   return attention
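For orientation, the streaming branch above amounts to an exponentially decayed running sum over chunks. The following is a sketch of the recurrence implied by the einsum expressions; the symbols S_t, z_t, d, ε and the feature map φ are notation introduced here rather than names from the code (d is window_decay, ε is _NUMERIC_STABLER):

S_t = d\,S_{t-1} + \sum_{i \in \text{chunk}\,t} \phi(k_i)\,v_i^{\top},
\qquad
z_t = d\,z_{t-1} + \sum_{i \in \text{chunk}\,t} \phi(k_i),
\qquad
\mathrm{attn}(q_j) = \frac{S_t^{\top}\phi(q_j)}{z_t^{\top}\phi(q_j) + \epsilon}
\quad \text{for each position } j \text{ in chunk } t.

After a chunk is processed, cache["kv"] holds d·S_t and cache["k_sum"] holds d·z_t (the decay is applied on write, not on read). Note also that every query position in a chunk reads the state accumulated through the end of its own chunk, so the streaming path is causal at chunk granularity rather than per token, matching the chunked windowed computation in the training branch.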
@@ -443,7 +459,7 @@ def expplus(data_orig,

 # pylint: disable=g-long-lambda
-_TRANSFORM_MAP = {
+_CAUSAL_SUPPORT_TRANSFORM_MAP = {
     "elu": functools.partial(
         _generalized_kernel,
@@ -476,11 +492,19 @@ _TRANSFORM_MAP = {
             h=lambda x: tf.math.exp(-0.5 * tf.math.sqrt(
                 tf.cast(tf.shape(x)[-1], tf.float32))),
         ),
-    "expplus": expplus,
     "identity": functools.partial(
         _generalized_kernel, f=lambda x: x, h=lambda x: 1)
 }
+_NON_CAUSAL_SUPPORT_TRANSFORM_MAP = {
+    "expplus": expplus,
+}
+_TRANSFORM_MAP = {
+    **_CAUSAL_SUPPORT_TRANSFORM_MAP,
+    **_NON_CAUSAL_SUPPORT_TRANSFORM_MAP
+}
 # pylint: enable=g-long-lambda
@@ -609,6 +633,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
                          feature_transform,
                          is_short_seq,
                          attention_mask=None,
+                         cache=None,
                          training=False,
                          numeric_stabler=_NUMERIC_STABLER):
     """Applies kernel attention with query, key, value tensors.
@@ -628,6 +653,8 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       attention_mask: a boolean mask of shape `[B, S]` that prevents attending
         to masked positions. Note that the mask is only applied to the keys.
         The user may want to mask the output if the query contains pads.
+      cache: Cache to accumulate history in memory. Used at inference time
+        (streaming, decoding) for causal attention.
       training: Python boolean indicating whether the layer should behave in
         training mode (adding dropout) or in inference mode (doing nothing).
       numeric_stabler: A scalar value added to avoid divide by 0.
@@ -682,7 +709,8 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
           chunk_length=self.causal_chunk_length,
           window_length=self.causal_window_length,
           window_decay=self.causal_window_decay,
-          padding=self.causal_padding)
+          padding=self.causal_padding,
+          cache=cache)
     else:
       kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
       denominator = 1.0 / (
@@ -709,7 +737,8 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
           name="attention_output_softmax")
       self._dropout_softmax = tf.keras.layers.Dropout(rate=self._dropout)

-  def call(self, query, value, key=None, attention_mask=None, training=False):
+  def call(self, query, value, key=None, attention_mask=None, cache=None,
+           training=False):
     """Compute attention with kernel mechanism.

     Args:
@@ -720,12 +749,29 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       attention_mask: a boolean mask of shape `[B, S]` that prevents attending
         to masked positions. Note that the mask is only applied to the keys.
         The user may want to mask the output if the query contains pads.
+      cache: Cache to accumulate history in memory. Used at inference time
+        (streaming, decoding) for causal attention.
       training: Python boolean indicating whether the layer should behave in
         training mode (adding dropout) or in inference mode (doing nothing).

     Returns:
       Multi-headed outputs of attention computation.
     """
+    if cache is not None:
+      if training:
+        raise ValueError("Cache is not supported when training is True.")
+      if not self.use_causal_windowed:
+        raise ValueError(
+            "Cache is not supported for non use_causal_windowed case.")
+      if self._begin_kernel:
+        raise ValueError(
+            "Cache is not supported when begin_kernel is set since the "
+            "behavior is too complicated.")
+      if self._feature_transform in _NON_CAUSAL_SUPPORT_TRANSFORM_MAP:
+        raise ValueError("Cache is not supported for feature_transform %s" %
+                         (self._feature_transform))
     if not self._built_from_signature:
       self._build_from_signature(query=query, value=value, key=key)
     if key is None:
@@ -761,7 +807,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       attention_output = self._compute_attention(query, key, value,
                                                   self._feature_transform,
                                                   self._is_short_seq,
-                                                  attention_mask, training)
+                                                  attention_mask, cache,
+                                                  training)

     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
     attention_output = self._dropout_layer(attention_output)
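To make the new cache contract concrete, here is a minimal usage sketch of the streaming path. It mirrors the new test added in kernel_attention_test.py below: the constructor arguments and the zero-initialized cache shapes are copied from that test, while the import alias and the exact sizes are assumptions chosen for illustration rather than canonical API documentation.

import tensorflow as tf
# Import alias assumed to match the test file's convention.
from official.nlp.modeling.layers import kernel_attention as attention

# Hypothetical sizes, similar to the streaming test below.
batch_size, seq_length, num_heads, key_dim, chunk_length = 2, 16, 12, 64, 4

layer = attention.KernelAttention(
    num_heads=num_heads,
    key_dim=key_dim,
    feature_transform="elu",      # must be a causal-support transform
    num_random_features=0,
    redraw=False,
    is_short_seq=False,
    begin_kernel=False,           # cache is rejected when begin_kernel is set
    use_causal_windowed=True,
    causal_chunk_length=chunk_length,
    causal_window_length=seq_length // chunk_length,
    causal_window_decay=0.9,      # streaming requires a decay in (0, 1)
    causal_padding=None)

x = tf.random.normal((batch_size, seq_length, key_dim))
mask = tf.ones((batch_size, seq_length), dtype=tf.float32)

# Zero-initialized cache, shaped as in the streaming test below.
cache = {
    "kv": tf.zeros((batch_size, num_heads, key_dim, key_dim)),
    "k_sum": tf.zeros((batch_size, 1, key_dim)),
}

outputs = []
for i in range(seq_length // chunk_length):
  sl = slice(i * chunk_length, (i + 1) * chunk_length)
  outputs.append(
      layer(query=x[:, sl, :], value=x[:, sl, :],
            attention_mask=mask[:, sl], cache=cache, training=False))
streamed = tf.concat(outputs, axis=1)
# The test below asserts this matches one non-streaming call on the full sequence.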
official/nlp/modeling/layers/kernel_attention_test.py @ 10673875
@@ -30,6 +30,64 @@ _BEGIN_KERNEL = [0, 512]

 class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):

+  # expplus is only designed for the bi-directional use case.
+  # exp can be numerically unstable.
+  @parameterized.parameters(
+      itertools.product(["relu", "elu"], [1, 4], [0.9]))
+  def test_causal_windowed_attention_projection_streaming(
+      self, feature_transform, causal_chunk_length, causal_weight_decay):
+    num_heads = 12
+    key_dim = 64
+    seq_length = 16
+    num_chunks = seq_length // causal_chunk_length
+    causal_window_length = num_chunks
+    batch_size = 2
+    training = False
+    num_random_features = 0
+    test_layer = attention.KernelAttention(
+        num_heads=num_heads,
+        key_dim=key_dim,
+        feature_transform=feature_transform,
+        num_random_features=num_random_features,
+        redraw=False,
+        is_short_seq=False,
+        begin_kernel=False,
+        use_causal_windowed=True,
+        causal_chunk_length=causal_chunk_length,
+        causal_window_length=causal_window_length,
+        causal_window_decay=causal_weight_decay,
+        causal_padding=None,
+    )
+    query = tf.random.normal(shape=(batch_size, seq_length, key_dim), seed=2)
+    value = query
+    encoder_inputs_mask = tf.ones((batch_size, seq_length), dtype=tf.int32)
+    masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
+    output = test_layer(
+        query=query, value=value, attention_mask=masks, training=training)
+    kv_cache = tf.zeros(
+        (batch_size, num_heads, key_dim,
+         num_random_features if num_random_features > 0 else key_dim))
+    k_sum_cache = tf.zeros((batch_size, 1, key_dim))
+    stream_output = []
+    cache = {"kv": kv_cache, "k_sum": k_sum_cache}
+    for i in range(num_chunks):
+      stream_output.append(
+          test_layer(
+              query=query[:, i * causal_chunk_length:(i + 1) *
+                          causal_chunk_length, :],
+              value=value[:, i * causal_chunk_length:(i + 1) *
+                          causal_chunk_length, :],
+              attention_mask=masks[:, i * causal_chunk_length:(i + 1) *
+                                   causal_chunk_length],
+              cache=cache,
+              training=training))
+    stream_output = tf.concat(stream_output, axis=1)
+    self.assertAllClose(output, stream_output)
+
   @parameterized.parameters(
       itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
                         _IS_SHORT_SEQ, _BEGIN_KERNEL))
@@ -196,6 +254,5 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
             [2, 1, 2, 2, 2]),
         winsum)

 if __name__ == "__main__":
   tf.test.main()