ModelZoo / ResNet50_tensorflow / Commits / 10673875

Commit 10673875, authored Sep 27, 2022 by Frederick Liu; committed by A. Unique TensorFlower, Sep 27, 2022

[kernel] Add streaming support.

PiperOrigin-RevId: 477214841

Parent: 798d318f
Showing 2 changed files with 158 additions and 53 deletions (+158 −53):

  official/nlp/modeling/layers/kernel_attention.py        +100  −52
  official/nlp/modeling/layers/kernel_attention_test.py     +58   −1
official/nlp/modeling/layers/kernel_attention.py  (view file @ 10673875)
...
@@ -160,7 +160,8 @@ def causal_windowed_performer_attention(query_matrix,
                                         chunk_length,
                                         window_length,
                                         window_decay=None,
-                                        padding=None):
+                                        padding=None,
+                                        cache=None):
   """Applies windowed causal kernel attention with query, key, value tensors.
 
   We partition the T-length input sequence into N chunks, each of
...
@@ -202,10 +203,13 @@ def causal_windowed_performer_attention(query_matrix,
       padding if padding is set to None. In the latter case, the axis dimension
       of the query, value and key input tensors must be divisible by the
       chunk_length.
+    cache: Cache to accumulate history in memory. Used at inference time
+      (streaming, decoding) for causal attention.
 
   Returns:
     Window causal performer attention of shape `[B, T, H, out_dim]`.
   """
+  if cache is None:
+    # Training
     old_shape = tf.shape(value_matrix)
     query_matrix = pad_to_chunk_length(query_matrix, -3, chunk_length, padding)
...
@@ -239,18 +243,30 @@ def causal_windowed_performer_attention(query_matrix,
     kp_v_winsum = weighted_window_sum(kp_v, window_length, decaying_weights)
     k_winsum = weighted_window_sum(k_sum, window_length, decaying_weights)
     numerator = tf.einsum("BTCHD,BTHDO->BTCHO", chunked_query_matrix,
                           kp_v_winsum)
     k_winsum = tf.squeeze(k_winsum, -3)
     denominator = tf.einsum("BTCHD,BTHD->BTCH", chunked_query_matrix, k_winsum)
     denominator = tf.expand_dims(denominator, -1) + _NUMERIC_STABLER
     attention = numerator / denominator
     attention = tf.reshape(attention, new_shape)
     start = tf.zeros([len(old_shape)], dtype=old_shape.dtype)
     attention = tf.slice(attention, start, old_shape)
+    # Queued window cache (drop instead of decay) not yet supported.
+  else:
+    # Streaming
+    if window_decay is None or window_decay > 1.0 or window_decay < 0.0:
+      raise ValueError("window_decay should be in (0.0, 1.0) and not None.")
+    kv = cache["kv"] + tf.einsum("BTHD,BTHO->BHOD", key_matrix, value_matrix)
+    cache["kv"] = kv * window_decay
+    k_sum = cache["k_sum"] + tf.reduce_sum(key_matrix, axis=1)
+    cache["k_sum"] = k_sum * window_decay
+    denominator = tf.einsum("BTHD,BHD->BTH", query_matrix, k_sum)
+    attention = tf.einsum("BTHD,BHOD,BTH->BTHO", query_matrix, kv,
+                          1.0 / (denominator + _NUMERIC_STABLER))
   return attention
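For intuition on the new `else` branch: instead of the windowed sums used at training time, streaming keeps exponentially decayed running sums, scaling `cache["kv"]` and `cache["k_sum"]` by `window_decay` after every chunk, so after n chunks the cache weights chunk i by `window_decay**(n - i)`. A minimal standalone sketch of that recurrence (not part of this commit; the tensor sizes and the closed-form check are illustrative only):

import tensorflow as tf

B, T, H, D, O = 2, 3, 4, 8, 8   # batch, chunk length, heads, key/feature dim, value dim
decay, num_chunks = 0.9, 5
keys = [tf.random.normal([B, T, H, D]) for _ in range(num_chunks)]
values = [tf.random.normal([B, T, H, O]) for _ in range(num_chunks)]

# Streaming update: carry decayed sums between chunks, as the new branch does.
cache = {"kv": tf.zeros([B, H, O, D]), "k_sum": tf.zeros([B, H, D])}
for k, v in zip(keys, values):
  cache["kv"] = (cache["kv"] + tf.einsum("BTHD,BTHO->BHOD", k, v)) * decay
  cache["k_sum"] = (cache["k_sum"] + tf.reduce_sum(k, axis=1)) * decay

# Closed form: chunk i (0-indexed) ends up weighted by decay**(num_chunks - i).
expected_kv = sum(
    decay**(num_chunks - i) * tf.einsum("BTHD,BTHO->BHOD", keys[i], values[i])
    for i in range(num_chunks))
tf.debugging.assert_near(cache["kv"], expected_kv, atol=1e-4)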
...
@@ -443,7 +459,7 @@ def expplus(data_orig,
 # pylint: disable=g-long-lambda
-_TRANSFORM_MAP = {
+_CAUSAL_SUPPORT_TRANSFORM_MAP = {
     "elu":
         functools.partial(
             _generalized_kernel,
...
@@ -476,11 +492,19 @@ _TRANSFORM_MAP = {
             h=lambda x: tf.math.exp(-0.5 * tf.math.sqrt(
                 tf.cast(tf.shape(x)[-1], tf.float32))),
         ),
-    "expplus":
-        expplus,
     "identity":
         functools.partial(_generalized_kernel, f=lambda x: x, h=lambda x: 1)
 }
+
+_NON_CAUSAL_SUPPORT_TRANSFORM_MAP = {
+    "expplus": expplus,
+}
+
+_TRANSFORM_MAP = {**_CAUSAL_SUPPORT_TRANSFORM_MAP,
+                  **_NON_CAUSAL_SUPPORT_TRANSFORM_MAP}
 # pylint: enable=g-long-lambda
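Splitting the registry makes streaming support queryable by key while leaving the merged `_TRANSFORM_MAP` lookup behaving as before. A small illustrative check (the helper name `supports_streaming` is hypothetical and not part of the commit; the cache path in `call` below performs the equivalent membership test inline):

def supports_streaming(feature_transform):
  # Streaming (cache != None) is only defined for causal-capable transforms.
  return feature_transform in _CAUSAL_SUPPORT_TRANSFORM_MAP

transform_fn = _TRANSFORM_MAP["elu"]       # merged map: lookups unchanged
assert supports_streaming("elu")
assert not supports_streaming("expplus")   # expplus is bi-directional only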
...
@@ -609,6 +633,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
                          feature_transform,
                          is_short_seq,
                          attention_mask=None,
+                         cache=None,
                          training=False,
                          numeric_stabler=_NUMERIC_STABLER):
     """Applies kernel attention with query, key, value tensors.
...
@@ -628,6 +653,8 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       attention_mask: a boolean mask of shape `[B, S]`, that prevents attending
         to masked positions. Note that the mask is only applied to the keys.
         User may want to mask the output if query contains pads.
+      cache: Cache to accumulate history in memory. Used at inference time
+        (streaming, decoding) for causal attention.
       training: Python boolean indicating whether the layer should behave in
         training mode (adding dropout) or in inference mode (doing nothing).
       numeric_stabler: A scalar value added to avoid divide by 0.
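The docstrings describe the cache only loosely; reading the einsum equations in the streaming branch, it is a plain dict of two tensors. The shapes below are inferred from the code rather than stated anywhere in the commit, and the helper name is hypothetical (the new test builds these tensors inline instead):

import tensorflow as tf

def make_streaming_cache(batch_size, num_heads, feature_dim, value_dim):
  """Hypothetical helper; feature_dim is num_random_features, or key_dim if 0."""
  return {
      # "kv" accumulates einsum("BTHD,BTHO->BHOD", key_prime, value):
      # [batch, heads, value_dim, feature_dim].
      "kv": tf.zeros([batch_size, num_heads, value_dim, feature_dim]),
      # "k_sum" accumulates reduce_sum(key_prime, axis=1): [batch, heads, feature_dim].
      # The new test uses a broadcastable [batch, 1, feature_dim] zero tensor instead.
      "k_sum": tf.zeros([batch_size, num_heads, feature_dim]),
  }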
...
@@ -682,7 +709,8 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
           chunk_length=self.causal_chunk_length,
           window_length=self.causal_window_length,
           window_decay=self.causal_window_decay,
-          padding=self.causal_padding)
+          padding=self.causal_padding,
+          cache=cache)
     else:
       kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
       denominator = 1.0 / (
...
@@ -709,7 +737,8 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
           name="attention_output_softmax")
       self._dropout_softmax = tf.keras.layers.Dropout(rate=self._dropout)
 
-  def call(self, query, value, key=None, attention_mask=None, training=False):
+  def call(self, query, value, key=None, attention_mask=None, cache=None,
+           training=False):
     """Compute attention with kernel mechanism.
 
     Args:
...
@@ -720,12 +749,29 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       attention_mask: a boolean mask of shape `[B, S]`, that prevents attending
         to masked positions. Note that the mask is only applied to the keys.
         User may want to mask the output if query contains pads.
+      cache: Cache to accumulate history in memory. Used at inference time
+        (streaming, decoding) for causal attention.
       training: Python boolean indicating whether the layer should behave in
         training mode (adding dropout) or in inference mode (doing nothing).
 
     Returns:
       Multi-headed outputs of attention computation.
     """
+    if cache is not None:
+      if training:
+        raise ValueError("Cache is not supported when training is True.")
+      if not self.use_causal_windowed:
+        raise ValueError(
+            "Cache is not supported for non use_causal_windowed case.")
+      if self._begin_kernel:
+        raise ValueError(
+            "Cache is not supported when begin_kernel is set since the "
+            "behavior is too complicated.")
+      if self._feature_transform in _NON_CAUSAL_SUPPORT_TRANSFORM_MAP:
+        raise ValueError(
+            "Cache is not supported for feature_transform %s" %
+            (self._feature_transform))
     if not self._built_from_signature:
       self._build_from_signature(query=query, value=value, key=key)
     if key is None:
...
@@ -761,7 +807,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     attention_output = self._compute_attention(query, key, value,
                                                self._feature_transform,
                                                self._is_short_seq,
-                                               attention_mask, training)
+                                               attention_mask,
+                                               cache,
+                                               training)
 
     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
     attention_output = self._dropout_layer(attention_output)
...
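For a user-facing view of the change, streaming inference drives `call` chunk by chunk with a persistent cache dict. The sketch below condenses the pattern exercised by the new test in kernel_attention_test.py; the import path and the specific hyperparameters are assumptions taken from that test, so treat it as illustrative rather than canonical:

import tensorflow as tf
# Import path as used by the test module (assumed).
from official.nlp.modeling.layers import kernel_attention as attention

layer = attention.KernelAttention(
    num_heads=12, key_dim=64, feature_transform="elu", num_random_features=0,
    redraw=False, is_short_seq=False, begin_kernel=False,
    use_causal_windowed=True, causal_chunk_length=4, causal_window_length=4,
    causal_window_decay=0.9, causal_padding=None)

batch, seq_len, chunk = 2, 16, 4
query = tf.random.normal([batch, seq_len, 64])
mask = tf.ones([batch, seq_len], dtype=tf.float32)

# Cache shapes inferred from the code: "kv" is [B, heads, value_dim, feature_dim];
# "k_sum" only needs to broadcast against [B, heads, feature_dim].
cache = {"kv": tf.zeros([batch, 12, 64, 64]), "k_sum": tf.zeros([batch, 1, 64])}

outputs = []
for i in range(seq_len // chunk):
  sl = slice(i * chunk, (i + 1) * chunk)
  outputs.append(layer(query=query[:, sl, :], value=query[:, sl, :],
                       attention_mask=mask[:, sl], cache=cache, training=False))
streamed = tf.concat(outputs, axis=1)
# Per the new test, `streamed` matches the output of a single full-sequence call.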
official/nlp/modeling/layers/kernel_attention_test.py  (view file @ 10673875)
...
@@ -30,6 +30,64 @@ _BEGIN_KERNEL = [0, 512]
 class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
 
+  # expplus is only designed for bi-directional use case.
+  # exp can be numeric unstable.
+  @parameterized.parameters(
+      itertools.product(["relu", "elu"], [1, 4], [0.9]))
+  def test_causal_windowed_attention_projection_streaming(
+      self, feature_transform, causal_chunk_length, causal_weight_decay):
+    num_heads = 12
+    key_dim = 64
+    seq_length = 16
+    num_chunks = seq_length // causal_chunk_length
+    causal_window_length = num_chunks
+    batch_size = 2
+    training = False
+    num_random_features = 0
+    test_layer = attention.KernelAttention(
+        num_heads=num_heads,
+        key_dim=key_dim,
+        feature_transform=feature_transform,
+        num_random_features=num_random_features,
+        redraw=False,
+        is_short_seq=False,
+        begin_kernel=False,
+        use_causal_windowed=True,
+        causal_chunk_length=causal_chunk_length,
+        causal_window_length=causal_window_length,
+        causal_window_decay=causal_weight_decay,
+        causal_padding=None,
+    )
+    query = tf.random.normal(
+        shape=(batch_size, seq_length, key_dim), seed=2)
+    value = query
+    encoder_inputs_mask = tf.ones((batch_size, seq_length), dtype=tf.int32)
+    masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
+    output = test_layer(
+        query=query,
+        value=value,
+        attention_mask=masks,
+        training=training)
+    kv_cache = tf.zeros(
+        (batch_size, num_heads, key_dim,
+         num_random_features if num_random_features > 0 else key_dim))
+    k_sum_cache = tf.zeros((batch_size, 1, key_dim))
+    stream_output = []
+    cache = {"kv": kv_cache, "k_sum": k_sum_cache}
+    for i in range(num_chunks):
+      stream_output.append(
+          test_layer(
+              query=query[:, i * causal_chunk_length:(i + 1) *
+                          causal_chunk_length, :],
+              value=value[:, i * causal_chunk_length:(i + 1) *
+                          causal_chunk_length, :],
+              attention_mask=masks[:, i * causal_chunk_length:(i + 1) *
+                                   causal_chunk_length],
+              cache=cache,
+              training=training))
+    stream_output = tf.concat(stream_output, axis=1)
+    self.assertAllClose(output, stream_output)
+
   @parameterized.parameters(
       itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
                         _IS_SHORT_SEQ, _BEGIN_KERNEL))
...
@@ -196,6 +254,5 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
             [2, 1, 2, 2, 2]),
         winsum)
 
-
 if __name__ == "__main__":
   tf.test.main()