ModelZoo / ResNet50_tensorflow · Commits

Commit 9cdb5d72
Authored Jun 11, 2020 by Chen Chen; committed by A. Unique TensorFlower on Jun 11, 2020

Internal change

PiperOrigin-RevId: 316053809
Parent: b426f52d
Showing 2 changed files with 171 additions and 165 deletions (+171, -165):

official/nlp/modeling/layers/talking_heads_attention.py (+82, -144)
official/nlp/modeling/layers/talking_heads_attention_test.py (+89, -21)
official/nlp/modeling/layers/talking_heads_attention.py (view file @ 9cdb5d72)
...
...
@@ -15,26 +15,40 @@
"""Talking Head Attention layer."""
# pylint: disable=g-classes-have-attributes
import
math
import
string
import
gin
import
tensorflow
as
tf
from
official.nlp.modeling.layers
import
dense_einsum
from
official.nlp.modeling.layers
import
masked_softmax
from
official.nlp.modeling.layers
import
attention
_CHR_IDX
=
string
.
ascii_lowercase
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
"Text"
)
@
gin
.
configurable
class
TalkingHeadsAttention
(
tf
.
keras
.
layers
.
Layer
):
class
TalkingHeadsAttention
(
attention
.
MultiHeadAttention
):
"""Implements Talking-Heads Attention.
https://arxiv.org/abs/2003.02436
This is an implementation of Talking-Heads Attention based on the paper
Talking-Heads Attention (https://arxiv.org/abs/2003.02436): it enhanced
multi-head attention by including linearprojections across the attention-heads
dimension, immediately before and after the softmax operation.
See the base class `MultiHeadAttention` for more details.
Arguments:
num_heads: Number of attention heads.
key_size: Size of each attention head.
key_size: Size of each attention head for query and key.
value_size: Size of each attention head for value.
dropout: Dropout probability.
use_bias: Boolean, whether the dense layers use bias vectors/matrices.
output_shape: The expected shape of an output tensor, besides the batch and
sequence dims. If not specified, projects back to the key feature dim.
attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features.
return_attention_scores: bool, if `True`, returns the multi-head attention
scores as an additional output argument.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
...
...
@@ -44,85 +58,34 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
    bias_constraint: Constraint for dense layer kernels.
  """

  def __init__(self,
               num_heads,
               key_size,
               dropout=0.0,
               output_shape=None,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    super(TalkingHeadsAttention, self).__init__(**kwargs)
    self._num_heads = num_heads
    self._key_size = key_size
    self._dropout = dropout
    self._output_shape = output_shape
    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf.keras.constraints.get(bias_constraint)

    self._query_dense = dense_einsum.DenseEinsum(
        output_shape=(self._num_heads, self._key_size),
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="query")

    self._key_dense = dense_einsum.DenseEinsum(
        output_shape=(self._num_heads, self._key_size),
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="key")

    self._value_dense = dense_einsum.DenseEinsum(
        output_shape=(self._num_heads, self._key_size),
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="value")

    self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[1])

    self._dropout = tf.keras.layers.Dropout(rate=self._dropout)

  def build(self, input_shape):
    if self._output_shape:
      output_shape = self._output_shape
    else:
      input_shape = tf.TensorShape(input_shape[0])
      output_shape = input_shape[-1]
    self._output_dense = dense_einsum.DenseEinsum(
        output_shape=output_shape,
        num_summed_dimensions=2,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="attention_output")

  def _build_attention(self, qkv_rank):
    """Builds multi-head dot-product attention computations.

    This function overrides base class to create additional linear projection
    that will be applied on attention scores before and after softmax.

    Args:
      qkv_rank: the rank of query, key, value tensors after projection.
    """
    super(TalkingHeadsAttention, self)._build_attention(qkv_rank)

    # Build an equation:
    # (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) ->
    # (<batch_dims>, num_heads_b, ...)
    # qkv_ranks has `batch_dims`, `attention_dims`, `num_heads` and `channels`.
    num_batch_dims = qkv_rank - len(self._attention_axes) - 2

    # The shape of attn_scores is:
    # (<batch_dims>, num_heads, <query_attn_dims>, <key_attn_dims>)
    attn_scores_rank = num_batch_dims + 1 + len(self._attention_axes) * 2
    scores_notation = _CHR_IDX[:attn_scores_rank]
    projection_notation = scores_notation[num_batch_dims] + (
        _CHR_IDX[attn_scores_rank])
    projected_scores_notation = scores_notation[:num_batch_dims] + (
        _CHR_IDX[attn_scores_rank] + scores_notation[num_batch_dims + 1:])
    self._talking_heads_equation = "%s,%s->%s" % (
        scores_notation, projection_notation, projected_scores_notation)

    self._pre_softmax_weight = self.add_weight(
        "pre_softmax_weight",
        shape=(self._num_heads, self._num_heads),
...
...
@@ -139,77 +102,52 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
        constraint=self._kernel_constraint,
        dtype=self.dtype,
        trainable=True)

    super(TalkingHeadsAttention, self).build(input_shape)

  def get_config(self):
    config = {
        "num_heads": self._num_heads,
        "key_size": self._key_size,
        "dropout": self._dropout,
        "kernel_initializer":
            tf.keras.initializers.serialize(self._kernel_initializer),
        "bias_initializer":
            tf.keras.initializers.serialize(self._bias_initializer),
        "kernel_regularizer":
            tf.keras.regularizers.serialize(self._kernel_regularizer),
        "bias_regularizer":
            tf.keras.regularizers.serialize(self._bias_regularizer),
        "activity_regularizer":
            tf.keras.regularizers.serialize(self._activity_regularizer),
        "kernel_constraint":
            tf.keras.constraints.serialize(self._kernel_constraint),
        "bias_constraint":
            tf.keras.constraints.serialize(self._bias_constraint)
    }
    base_config = super(TalkingHeadsAttention, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs, attention_mask=None):
    from_tensor = inputs[0]
    to_tensor = inputs[1]

    # Scalar dimensions referenced here:
    #   B = batch size (number of sequences)
    #   F = `from_tensor` sequence length
    #   T = `to_tensor` sequence length
    #   N = L = `num_attention_heads`
    #   H = `size_per_head`
    # `query_tensor` = [B, F, N ,H]
    query_tensor = self._query_dense(from_tensor)

    # `key_tensor` = [B, T, N, H]
    key_tensor = self._key_dense(to_tensor)

    # `value_tensor` = [B, T, N, H]
    value_tensor = self._value_dense(to_tensor)

  def _compute_attention(self,
                         query_tensor,
                         key_tensor,
                         value_tensor,
                         attention_mask=None):
    """Applies Dot-product attention with query, key, value tensors.

    This function overrides base class to apply additional linear projection
    on attention scores before and after softmax.

    Args:
      query_tensor: Projected query `Tensor` of shape `[B, T, N, key_size]`.
      key_tensor: Projected key `Tensor` of shape `[B, T, N, key_size]`.
      value_tensor: Projected value `Tensor` of shape `[B, T, N, value_size]`.
      attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
        attention to certain positions.

    Returns:
      attention_output: Multi-headed outputs of attention computation.
      attention_scores: Multi-headed attention weights.
    """
    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
    attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
                                 query_tensor)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(self._key_size)))

    # Apply talking heads before softmax.
    attention_scores = tf.einsum("BNFT,NL->BLFT", attention_scores,
                                 self._pre_softmax_weight)
    # Apply linear projection before softmax
    attention_scores = tf.einsum(self._talking_heads_equation, attention_scores,
                                 self._pre_softmax_weight)

    # Normalize the attention scores to probabilities.
    # `attention_probs` = [B, N, F, T]
    attention_probs = self._masked_softmax(attention_scores, attention_mask)
    # `attention_scores` = [B, N, T, S]
    attention_scores = self._masked_softmax(attention_scores, attention_mask)

    # Apply talking heads after softmax.
    attention_probs = tf.einsum("BNFT,NL->BLFT", attention_probs,
                                 self._post_softmax_weight)
    # Apply linear projection after softmax
    attention_scores = tf.einsum(self._talking_heads_equation, attention_scores,
                                 self._post_softmax_weight)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = self._dropout(attention_probs)
    attention_scores_dropout = self._dropout_layer(attention_scores)

    # `context_layer` = [B, F, N, H]
    attention_output = tf.einsum("BNFT,BTNH->BFNH", attention_probs,
                                 value_tensor)
    attention_output = self._output_dense(attention_output)

    return attention_output
    # `context_layer` = [B, T, N, H]
    attention_output = tf.einsum(self._combine_equation, attention_scores_dropout,
                                 value_tensor)

    return attention_output, attention_scores
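For orientation, here is a minimal standalone sketch (not part of this diff) of the head-mixing step that the layer wraps around softmax. It uses the fixed-rank equation "BNFT,NL->BLFT" that the pre-refactor call() hard-coded; the refactored _build_attention instead derives an equivalent _talking_heads_equation from _CHR_IDX so it works for arbitrary attention ranks. Shapes and weights below are purely illustrative.

    import tensorflow as tf

    B, N, F, T = 2, 12, 5, 7            # batch, heads, query length, key length
    scores = tf.random.normal((B, N, F, T))
    pre_w = tf.random.normal((N, N))    # pre-softmax head-mixing weight
    post_w = tf.random.normal((N, N))   # post-softmax head-mixing weight

    scores = tf.einsum("BNFT,NL->BLFT", scores, pre_w)   # mix raw scores across heads
    probs = tf.nn.softmax(scores, axis=-1)               # normalize over key positions
    probs = tf.einsum("BNFT,NL->BLFT", probs, post_w)    # mix probabilities across heads
    print(probs.shape)                                   # (2, 12, 5, 7)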
official/nlp/modeling/layers/talking_heads_attention_test.py (view file @ 9cdb5d72)
...
...
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np
import tensorflow as tf
...
...
@@ -27,58 +28,97 @@ from official.nlp.modeling.layers import talking_heads_attention
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
# This test is revised base on attention.MultiHeadAttentionTest.
@keras_parameterized.run_all_keras_modes
class MultiHeadAttentionTest(keras_parameterized.TestCase):
class TalkingHeadsAttentionTest(keras_parameterized.TestCase):

  def test_non_masked_attention(self):
  @parameterized.named_parameters(
      ("key_value_same_proj", None, None, [40, 80]),
      ("key_value_different_proj", 32, 60, [40, 60]),
  )
  def test_non_masked_attention(self, value_size, output_shape, output_dims):
    """Test that the attention layer can be created without a mask tensor."""
    test_layer = talking_heads_attention.TalkingHeadsAttention(
        num_heads=12, key_size=64)
        num_heads=12,
        key_size=64,
        value_size=value_size,
        output_shape=output_shape)
    # Create a 3-dimensional input (the first dimension is implicit).
    from_tensor = tf.keras.Input(shape=(40, 80))
    to_tensor = tf.keras.Input(shape=(20, 80))
    output = test_layer([from_tensor, to_tensor])
    self.assertEqual(output.shape.as_list(), [None, 40, 80])
    query = tf.keras.Input(shape=(40, 80))
    value = tf.keras.Input(shape=(20, 80))
    output = test_layer([query, value])
    self.assertEqual(output.shape.as_list(), [None] + output_dims)

  def test_non_masked_self_attention(self):
    """Test with one input (self-attenntion) and no mask tensor."""
    test_layer = talking_heads_attention.TalkingHeadsAttention(
        num_heads=12, key_size=64)
    # Create a 3-dimensional input (the first dimension is implicit).
    from_tensor = tf.keras.Input(shape=(40, 80))
    output = test_layer([from_tensor, from_tensor])
    query = tf.keras.Input(shape=(40, 80))
    output = test_layer([query, query])
    self.assertEqual(output.shape.as_list(), [None, 40, 80])

  def test_masked_attention(self):
  def test_attention_scores(self):
    """Test attention outputs with coefficients."""
    test_layer = talking_heads_attention.TalkingHeadsAttention(
        num_heads=12, key_size=64, return_attention_scores=True)
    # Create a 3-dimensional input (the first dimension is implicit).
    query = tf.keras.Input(shape=(40, 80))
    output, coef = test_layer([query, query])
    self.assertEqual(output.shape.as_list(), [None, 40, 80])
    self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])

  @parameterized.named_parameters(("with_bias", True), ("no_bias", False))
  def test_masked_attention(self, use_bias):
    """Test with a mask tensor."""
    test_layer = talking_heads_attention.TalkingHeadsAttention(
        num_heads=2, key_size=2)
        num_heads=12, key_size=2, use_bias=use_bias)
    # Create a 3-dimensional input (the first dimension is implicit).
    from_tensor = tf.keras.Input(shape=(4, 8))
    to_tensor = tf.keras.Input(shape=(2, 8))
    batch_size = 3
    query = tf.keras.Input(shape=(4, 8))
    value = tf.keras.Input(shape=(2, 8))
    mask_tensor = tf.keras.Input(shape=(4, 2))
    output = test_layer([from_tensor, to_tensor], mask_tensor)
    output = test_layer([query, value], mask_tensor)

    # Create a model containing the test layer.
    model = tf.keras.Model([from_tensor, to_tensor, mask_tensor], output)
    model = tf.keras.Model([query, value, mask_tensor], output)

    # Generate data for the input (non-mask) tensors.
    from_data = 10 * np.random.random_sample((3, 4, 8))
    to_data = 10 * np.random.random_sample((3, 2, 8))
    from_data = 10 * np.random.random_sample((batch_size, 4, 8))
    to_data = 10 * np.random.random_sample((batch_size, 2, 8))

    # Invoke the data with a random set of mask data. This should mask at least
    # one element.
    mask_data = np.random.randint(2, size=(3, 4, 2))
    mask_data = np.random.randint(2, size=(batch_size, 4, 2))
    masked_output_data = model.predict([from_data, to_data, mask_data])

    # Invoke the same data, but with a null mask (where no elements are masked).
    null_mask_data = np.ones((3, 4, 2))
    null_mask_data = np.ones((batch_size, 4, 2))
    unmasked_output_data = model.predict([from_data, to_data, null_mask_data])

    # Because one data is masked and one is not, the outputs should not be the
    # same.
    self.assertNotAllClose(masked_output_data, unmasked_output_data)

    # Tests the layer with three inputs: Q, K, V.
    key = tf.keras.Input(shape=(2, 8))
    output = test_layer([query, value, key], mask_tensor)
    model = tf.keras.Model([query, value, key, mask_tensor], output)

    masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
    unmasked_output_data = model.predict(
        [from_data, to_data, to_data, null_mask_data])
    # Because one data is masked and one is not, the outputs should not be the
    # same.
    self.assertNotAllClose(masked_output_data, unmasked_output_data)

    if use_bias:
      self.assertLen(test_layer._query_dense.trainable_variables, 2)
      self.assertLen(test_layer._output_dense.trainable_variables, 2)
    else:
      self.assertLen(test_layer._query_dense.trainable_variables, 1)
      self.assertLen(test_layer._output_dense.trainable_variables, 1)

  def test_initializer(self):
    """Test with a specified initializer."""
    test_layer = talking_heads_attention.TalkingHeadsAttention(
...
...
@@ -86,10 +126,38 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
        key_size=64,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
    # Create a 3-dimensional input (the first dimension is implicit).
    from_tensor = tf.keras.Input(shape=(40, 80))
    output = test_layer([from_tensor, from_tensor])
    query = tf.keras.Input(shape=(40, 80))
    output = test_layer([query, query])
    self.assertEqual(output.shape.as_list(), [None, 40, 80])

  @parameterized.named_parameters(
      ("4d_inputs_one_free_batch", [3, 4], [3, 2], [4, 2], (2,)),
      ("4D_inputs_2D_attention", [3, 4], [3, 2], [3, 4, 3, 2], (1, 2)),
      ("5D_inputs_2D_attention", [5, 3, 4], [5, 3, 2], [3, 4, 3, 2], (2, 3)))
  def test_high_dim_attention(self, q_dims, v_dims, mask_dims, attention_axes):
    """Test with a mask tensor."""
    test_layer = talking_heads_attention.TalkingHeadsAttention(
        num_heads=12, key_size=2, attention_axes=attention_axes)
    batch_size, hidden_size = 3, 8
    # Generate data for the input (non-mask) tensors.
    query_shape = [batch_size] + q_dims + [hidden_size]
    value_shape = [batch_size] + v_dims + [hidden_size]
    mask_shape = [batch_size] + mask_dims
    query = 10 * np.random.random_sample(query_shape)
    value = 10 * np.random.random_sample(value_shape)

    # Invoke the data with a random set of mask data. This should mask at least
    # one element.
    mask_data = np.random.randint(2, size=mask_shape).astype("bool")
    output = test_layer([query, value], mask_data)

    # Invoke the same data, but with a null mask (where no elements are masked).
    null_mask_data = np.ones(mask_shape)
    unmasked_output = test_layer([query, value], null_mask_data)
    # Because one data is masked and one is not, the outputs should not be the
    # same.
    self.assertNotAllClose(output, unmasked_output)


if __name__ == "__main__":
  tf.test.main()
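For reference, a minimal usage sketch mirroring test_non_masked_attention above. It is not part of this commit and assumes the TF Model Garden `official` package is installed and importable.

    import tensorflow as tf
    from official.nlp.modeling.layers import talking_heads_attention

    # Build the layer with the post-refactor API exercised by the tests.
    layer = talking_heads_attention.TalkingHeadsAttention(num_heads=12, key_size=64)
    query = tf.keras.Input(shape=(40, 80))
    value = tf.keras.Input(shape=(20, 80))
    output = layer([query, value])          # -> shape [None, 40, 80]
    model = tf.keras.Model(inputs=[query, value], outputs=output)
    model.summary()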