ModelZoo / ResNet50_tensorflow · Commits

Commit 0b579232
Authored Nov 25, 2019 by Hongkun Yu; committed by A. Unique TensorFlower on Nov 25, 2019

Implement CachedAttention layer. This is useful for decoders.

PiperOrigin-RevId: 282439719
Parent: bcce419a

Showing 3 changed files with 150 additions and 3 deletions:

  official/nlp/modeling/layers/__init__.py        +1   -1
  official/nlp/modeling/layers/attention.py       +83  -1
  official/nlp/modeling/layers/attention_test.py  +66  -1
official/nlp/modeling/layers/__init__.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Layers package definition."""
-from official.nlp.modeling.layers.attention import Attention
+from official.nlp.modeling.layers.attention import *  # pylint: disable=wildcard-import
 from official.nlp.modeling.layers.dense_einsum import DenseEinsum
 from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
 from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
official/nlp/modeling/layers/attention.py

@@ -119,7 +119,7 @@ class Attention(tf.keras.layers.Layer):
     self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
 
   def compute_output_shape(self, input_shape):
-    # TODO(momernick): validate tensor dimensios
+    # TODO(momernick): validate tensor dimensions.
     from_tensor_shape = tf.TensorShape(input_shape[0])
     batch = from_tensor_shape[0]
     from_tensor_length = from_tensor_shape[1]
@@ -188,3 +188,85 @@ class Attention(tf.keras.layers.Layer):
     # `context_layer` = [B, F, N, H]
     return tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor)
+
+
+@tf.keras.utils.register_keras_serializable(package="Text")
+class CachedAttention(Attention):
+  """Attention layer with cache used for auto-regressive decoding.
+
+  Attributes:
+    num_heads: Number of attention heads.
+    head_size: Size of each attention head.
+    **kwargs: Other keyword arguments inherit from `Attention` class.
+  """
+
+  def __init__(self, num_heads, head_size, **kwargs):
+    super(CachedAttention, self).__init__(num_heads, head_size, **kwargs)
+
+  def _update_cache(self, key_tensor, value_tensor, cache, decode_loop_step):
+    """Updates cache states and gets full-length key/value tensors."""
+    # Combines cached keys and values with new keys and values.
+    if decode_loop_step is not None:
+      # TPU special case.
+      key_seq_dim = cache["key"].shape.as_list()[1]
+      indices = tf.reshape(
+          tf.one_hot(decode_loop_step, key_seq_dim, dtype=key_tensor.dtype),
+          [1, key_seq_dim, 1, 1])
+      key_tensor = cache["key"] + key_tensor * indices
+      value_seq_dim = cache["value"].shape.as_list()[1]
+      indices = tf.reshape(
+          tf.one_hot(decode_loop_step, value_seq_dim,
+                     dtype=value_tensor.dtype),
+          [1, value_seq_dim, 1, 1])
+      value_tensor = cache["value"] + value_tensor * indices
+    else:
+      key_tensor = tf.concat(
+          [tf.cast(cache["key"], key_tensor.dtype), key_tensor], axis=1)
+      value_tensor = tf.concat(
+          [tf.cast(cache["value"], value_tensor.dtype), value_tensor], axis=1)
+
+    # Update cache
+    cache["key"] = key_tensor
+    cache["value"] = value_tensor
+
+    return key_tensor, value_tensor
+
+  def call(self, inputs, decode_loop_step=None):
+    from_tensor = inputs[0]
+    to_tensor = inputs[1]
+    attention_mask = inputs[2] if len(inputs) >= 3 else None
+    cache = inputs[3] if len(inputs) >= 4 else None
+
+    # Scalar dimensions referenced here:
+    #   B = batch size (number of sequences)
+    #   F = `from_tensor` sequence length
+    #   T = `to_tensor` sequence length
+    #   N = `num_attention_heads`
+    #   H = `size_per_head`
+    # `query_tensor` = [B, F, N ,H]
+    query_tensor = self._query_dense(from_tensor)
+
+    # `key_tensor` = [B, T, N, H]
+    key_tensor = self._key_dense(to_tensor)
+
+    # `value_tensor` = [B, T, N, H]
+    value_tensor = self._value_dense(to_tensor)
+
+    if cache:
+      key_tensor, value_tensor = self._update_cache(key_tensor, value_tensor,
+                                                    cache, decode_loop_step)
+
+    # Take the dot product between "query" and "key" to get the raw
+    # attention scores.
+    attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
+    attention_scores = tf.multiply(attention_scores,
+                                   1.0 / math.sqrt(float(self._head_size)))
+
+    # Normalize the attention scores to probabilities.
+    # `attention_probs` = [B, N, F, T]
+    attention_probs = self._masked_softmax([attention_scores, attention_mask])
+
+    # This is actually dropping out entire tokens to attend to, which might
+    # seem a bit unusual, but is taken from the original Transformer paper.
+    attention_probs = self._dropout(attention_probs)
+
+    # `context_layer` = [B, F, N, H]
+    return tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor), cache
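The commit message says the new layer targets decoders. The sketch below is not part of the commit; it is one way the concat-based (GPU/CPU) cache path above could drive a token-by-token decode loop. The hidden size, head counts, and dummy inputs are made-up values, the cache layout follows _create_cache in the tests further down, and it assumes the model-garden repository is importable on PYTHONPATH.

# Illustrative only -- not part of this commit. Module path and call
# convention are taken from the diff above; parameter values are invented.
import numpy as np
import tensorflow as tf

from official.nlp.modeling.layers import attention

batch_size, num_heads, head_size, hidden_size = 3, 2, 2, 8
layer = attention.CachedAttention(num_heads=num_heads, head_size=head_size)

# GPU/CPU-style cache: the sequence dimension starts at 0 and is grown by
# tf.concat inside _update_cache (decode_loop_step is left as None).
cache = {
    "key": tf.zeros([batch_size, 0, num_heads, head_size]),
    "value": tf.zeros([batch_size, 0, num_heads, head_size]),
}

for step in range(4):
  # One new token per step; self-attention, so from_tensor == to_tensor.
  token_states = tf.random.uniform([batch_size, 1, hidden_size])
  # The current token may attend to every cached position plus itself.
  mask = np.ones((batch_size, 1, step + 1), dtype=np.int32)
  context, cache = layer([token_states, token_states, mask, cache])
  # context: [batch, 1, num_heads, head_size];
  # cache["key"]: [batch, step + 1, num_heads, head_size].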
official/nlp/modeling/layers/attention_test.py

@@ -88,5 +88,70 @@ class AttentionLayerTest(keras_parameterized.TestCase):
     self.assertEqual(output.shape.as_list(), [None, 40, 12, 64])
 
 
-if __name__ == '__main__':
+def _create_cache(batch_size, init_decode_length, num_heads, head_size):
+  return {
+      "key":
+          tf.zeros([batch_size, init_decode_length, num_heads, head_size],
+                   dtype=tf.float32),
+      "value":
+          tf.zeros([batch_size, init_decode_length, num_heads, head_size],
+                   dtype=tf.float32)
+  }
+
+
+@keras_parameterized.run_all_keras_modes
+class CachedAttentionTest(keras_parameterized.TestCase):
+
+  def test_masked_attention(self):
+    """Test with a mask tensor."""
+    num_heads, head_size = 2, 2
+    # Create a 3-dimensional input (the first dimension is implicit).
+    from_seq_length = 4
+    batch_size = 3
+    # GPU/CPU case.
+    init_decode_length = 0
+    # Directly tests the keras layer.
+    cache = _create_cache(batch_size, init_decode_length, num_heads, head_size)
+    layer = attention.CachedAttention(num_heads=num_heads, head_size=head_size)
+
+    # Generate data for the input (non-mask) tensors.
+    from_data = tf.zeros((batch_size, from_seq_length, 8), dtype=np.float32)
+    # Invoke the data with a random set of mask data. This should mask at least
+    # one element.
+    mask_data = np.random.randint(
+        2, size=(batch_size, from_seq_length, from_seq_length))
+    masked_output_data, cache = layer([from_data, from_data, mask_data, cache])
+    self.assertEqual(masked_output_data.shape, (3, 4, 2, 2))
+    self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
+
+    # Tests inputs without cache.
+    masked_output_data, cache = layer([from_data, from_data, mask_data])
+    self.assertEqual(masked_output_data.shape, (3, 4, 2, 2))
+    self.assertIsNone(cache)
+
+  def test_padded_decode(self):
+    """Test with a mask tensor."""
+    num_heads, head_size = 2, 2
+    from_seq_length = 4
+    # TPU decoding should pre-allocate the entire sequence.
+    batch_size = 3
+    init_decode_length = from_seq_length
+    # Directly tests the keras layer.
+    cache = _create_cache(batch_size, init_decode_length, num_heads, head_size)
+    layer = attention.CachedAttention(num_heads=num_heads, head_size=head_size)
+
+    # Generate data for the input (non-mask) tensors.
+    from_data = tf.zeros((batch_size, from_seq_length, 8), dtype=np.float32)
+    decode_loop_step = 2
+    mask_data = np.random.randint(
+        2,
+        size=(batch_size, from_seq_length, from_seq_length),
+        dtype=np.int32)
+    # Testing the invocation directly as Keras cannot consume inputs correctly.
+    masked_output_data, cache = layer(
+        [from_data, from_data, mask_data, cache],
+        decode_loop_step=decode_loop_step)
+    self.assertEqual(masked_output_data.shape, (3, 4, 2, 2))
+    self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
+
+
+if __name__ == "__main__":
   tf.test.main()
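For the padded-decode path exercised by test_padded_decode, the cache is pre-allocated to the full sequence length and _update_cache writes each step with a one-hot mask instead of concatenating. A standalone illustration of that broadcast follows; shapes and values are made up and not taken from the commit.

# Illustrative only -- mirrors the TPU branch of CachedAttention._update_cache.
import tensorflow as tf

batch, seq_len, num_heads, head_size = 1, 4, 2, 2
decode_loop_step = 2

# Pre-allocated cache and the key tensor computed at the current step.
cached_key = tf.zeros([batch, seq_len, num_heads, head_size])
new_key = tf.ones([batch, seq_len, num_heads, head_size])

# One-hot over the sequence dimension, reshaped for broadcasting.
indices = tf.reshape(
    tf.one_hot(decode_loop_step, seq_len, dtype=new_key.dtype),
    [1, seq_len, 1, 1])

# Only position `decode_loop_step` receives the new key; because that slot in
# the pre-allocated cache is zero, the addition acts as an in-place write.
updated_key = cached_key + new_key * indices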