ModelZoo / ResNet50_tensorflow / Commits / 9cdb5d72

Commit 9cdb5d72, authored Jun 11, 2020 by Chen Chen, committed by A. Unique TensorFlower on Jun 11, 2020.

Internal change

PiperOrigin-RevId: 316053809
Parent: b426f52d
Changes: 2 changed files, with 171 additions and 165 deletions (+171 / -165).

  official/nlp/modeling/layers/talking_heads_attention.py        +82  -144
  official/nlp/modeling/layers/talking_heads_attention_test.py   +89  -21
File: official/nlp/modeling/layers/talking_heads_attention.py
@@ -15,26 +15,40 @@
 """Talking Head Attention layer."""
 # pylint: disable=g-classes-have-attributes
 import math
+import string

 import gin
 import tensorflow as tf

-from official.nlp.modeling.layers import dense_einsum
-from official.nlp.modeling.layers import masked_softmax
+from official.nlp.modeling.layers import attention
+
+_CHR_IDX = string.ascii_lowercase


 @tf.keras.utils.register_keras_serializable(package="Text")
 @gin.configurable
-class TalkingHeadsAttention(tf.keras.layers.Layer):
+class TalkingHeadsAttention(attention.MultiHeadAttention):
   """Implements Talking-Heads Attention.

-  https://arxiv.org/abs/2003.02436
+  This is an implementation of Talking-Heads Attention based on the paper
+  Talking-Heads Attention (https://arxiv.org/abs/2003.02436): it enhances
+  multi-head attention by including linear projections across the
+  attention-heads dimension, immediately before and after the softmax
+  operation.
+
+  See the base class `MultiHeadAttention` for more details.

   Arguments:
     num_heads: Number of attention heads.
-    key_size: Size of each attention head.
+    key_size: Size of each attention head for query and key.
+    value_size: Size of each attention head for value.
     dropout: Dropout probability.
+    use_bias: Boolean, whether the dense layers use bias vectors/matrices.
     output_shape: The expected shape of an output tensor, besides the batch and
       sequence dims. If not specified, projects back to the key feature dim.
+    attention_axes: axes over which the attention is applied. `None` means
+      attention over all axes, but batch, heads, and features.
+    return_attention_scores: bool, if `True`, returns the multi-head attention
+      scores as an additional output argument.
     kernel_initializer: Initializer for dense layer kernels.
     bias_initializer: Initializer for dense layer biases.
     kernel_regularizer: Regularizer for dense layer kernels.
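The updated argument list adds `value_size`, `use_bias`, `attention_axes`, and `return_attention_scores`, which now come from the `MultiHeadAttention` base class. A minimal usage sketch, mirroring the updated tests further down (illustrative only, assuming the TensorFlow Model Garden `official` package is importable):

import tensorflow as tf

from official.nlp.modeling.layers import talking_heads_attention

# Self-attention over a [batch, 40, 80] input, returning attention scores.
layer = talking_heads_attention.TalkingHeadsAttention(
    num_heads=12, key_size=64, return_attention_scores=True)
query = tf.keras.Input(shape=(40, 80))
output, scores = layer([query, query])
print(output.shape.as_list())   # [None, 40, 80]
print(scores.shape.as_list())   # [None, 12, 40, 40]

With `return_attention_scores=True` the layer returns the post-softmax, post-projection attention weights alongside the attention output.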
@@ -44,85 +58,34 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
     bias_constraint: Constraint for dense layer kernels.
   """

-  def __init__(self,
-               num_heads,
-               key_size,
-               dropout=0.0,
-               output_shape=None,
-               kernel_initializer="glorot_uniform",
-               bias_initializer="zeros",
-               kernel_regularizer=None,
-               bias_regularizer=None,
-               activity_regularizer=None,
-               kernel_constraint=None,
-               bias_constraint=None,
-               **kwargs):
-    super(TalkingHeadsAttention, self).__init__(**kwargs)
-    self._num_heads = num_heads
-    self._key_size = key_size
-    self._dropout = dropout
-    self._output_shape = output_shape
-    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
-    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
-    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
-    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
-    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
-    self._bias_constraint = tf.keras.constraints.get(bias_constraint)
-
-    self._query_dense = dense_einsum.DenseEinsum(
-        output_shape=(self._num_heads, self._key_size),
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="query")
-    self._key_dense = dense_einsum.DenseEinsum(
-        output_shape=(self._num_heads, self._key_size),
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="key")
-    self._value_dense = dense_einsum.DenseEinsum(
-        output_shape=(self._num_heads, self._key_size),
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="value")
-    self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[1])
-    self._dropout = tf.keras.layers.Dropout(rate=self._dropout)
-
-  def build(self, input_shape):
-    if self._output_shape:
-      output_shape = self._output_shape
-    else:
-      input_shape = tf.TensorShape(input_shape[0])
-      output_shape = input_shape[-1]
-    self._output_dense = dense_einsum.DenseEinsum(
-        output_shape=output_shape,
-        num_summed_dimensions=2,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="attention_output")
+  def _build_attention(self, qkv_rank):
+    """Builds multi-head dot-product attention computations.
+
+    This function overrides base class to create additional linear projection
+    that will be applied on attention scores before and after softmax.
+
+    Args:
+      qkv_rank: the rank of query, key, value tensors after projection.
+    """
+    super(TalkingHeadsAttention, self)._build_attention(qkv_rank)
+
+    # Build an equation:
+    # (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) ->
+    # (<batch_dims>, num_heads_b, ...)
+    # qkv_rank has `batch_dims`, `attention_dims`, `num_heads` and `channels`.
+    num_batch_dims = qkv_rank - len(self._attention_axes) - 2
+
+    # The shape of attn_scores is:
+    # (<batch_dims>, num_heads, <query_attn_dims>, <key_attn_dims>)
+    attn_scores_rank = num_batch_dims + 1 + len(self._attention_axes) * 2
+    scores_notation = _CHR_IDX[:attn_scores_rank]
+    projection_notation = scores_notation[num_batch_dims] + (
+        _CHR_IDX[attn_scores_rank])
+    projected_scores_notation = scores_notation[:num_batch_dims] + (
+        _CHR_IDX[attn_scores_rank] + scores_notation[num_batch_dims + 1:])
+    self._talking_heads_equation = "%s,%s->%s" % (
+        scores_notation, projection_notation, projected_scores_notation)
+
     self._pre_softmax_weight = self.add_weight(
         "pre_softmax_weight",
         shape=(self._num_heads, self._num_heads),
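To see what `_talking_heads_equation` ends up as, here is a worked example for the common case of rank-4 query/key/value tensors ([batch, seq, heads, channels]) with a single attention axis. This is an editor's illustration, not code from the commit; it only reproduces the string arithmetic above outside the class:

import string

_CHR_IDX = string.ascii_lowercase

qkv_rank = 4            # [batch, seq, num_heads, channels]
num_attention_axes = 1  # attention over the single sequence axis

num_batch_dims = qkv_rank - num_attention_axes - 2               # -> 1
attn_scores_rank = num_batch_dims + 1 + num_attention_axes * 2   # -> 4
scores_notation = _CHR_IDX[:attn_scores_rank]                    # "abcd"
projection_notation = (scores_notation[num_batch_dims] +
                       _CHR_IDX[attn_scores_rank])                # "be"
projected_scores_notation = (scores_notation[:num_batch_dims] +
                             _CHR_IDX[attn_scores_rank] +
                             scores_notation[num_batch_dims + 1:]) # "aecd"
print("%s,%s->%s" % (scores_notation, projection_notation,
                     projected_scores_notation))
# abcd,be->aecd, i.e. (batch, heads, q_len, k_len) x (heads, heads')
#                     -> (batch, heads', q_len, k_len)

In other words, the equation contracts the per-head attention scores against an [num_heads, num_heads] weight, mixing information across heads while leaving the batch and attention axes untouched.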
@@ -139,77 +102,52 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
         constraint=self._kernel_constraint,
         dtype=self.dtype,
         trainable=True)
-    super(TalkingHeadsAttention, self).build(input_shape)
-
-  def get_config(self):
-    config = {
-        "num_heads": self._num_heads,
-        "key_size": self._key_size,
-        "dropout": self._dropout,
-        "kernel_initializer":
-            tf.keras.initializers.serialize(self._kernel_initializer),
-        "bias_initializer":
-            tf.keras.initializers.serialize(self._bias_initializer),
-        "kernel_regularizer":
-            tf.keras.regularizers.serialize(self._kernel_regularizer),
-        "bias_regularizer":
-            tf.keras.regularizers.serialize(self._bias_regularizer),
-        "activity_regularizer":
-            tf.keras.regularizers.serialize(self._activity_regularizer),
-        "kernel_constraint":
-            tf.keras.constraints.serialize(self._kernel_constraint),
-        "bias_constraint":
-            tf.keras.constraints.serialize(self._bias_constraint)
-    }
-    base_config = super(TalkingHeadsAttention, self).get_config()
-    return dict(list(base_config.items()) + list(config.items()))
-
-  def call(self, inputs, attention_mask=None):
-    from_tensor = inputs[0]
-    to_tensor = inputs[1]
-
-    # Scalar dimensions referenced here:
-    #   B = batch size (number of sequences)
-    #   F = `from_tensor` sequence length
-    #   T = `to_tensor` sequence length
-    #   N = L = `num_attention_heads`
-    #   H = `size_per_head`
-    # `query_tensor` = [B, F, N, H]
-    query_tensor = self._query_dense(from_tensor)
-
-    # `key_tensor` = [B, T, N, H]
-    key_tensor = self._key_dense(to_tensor)
-
-    # `value_tensor` = [B, T, N, H]
-    value_tensor = self._value_dense(to_tensor)
+  def _compute_attention(self,
+                         query_tensor,
+                         key_tensor,
+                         value_tensor,
+                         attention_mask=None):
+    """Applies dot-product attention with query, key, value tensors.
+
+    This function overrides base class to apply additional linear projection
+    on attention scores before and after softmax.
+
+    Args:
+      query_tensor: Projected query `Tensor` of shape `[B, T, N, key_size]`.
+      key_tensor: Projected key `Tensor` of shape `[B, T, N, key_size]`.
+      value_tensor: Projected value `Tensor` of shape `[B, T, N, value_size]`.
+      attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
+        attention to certain positions.
+
+    Returns:
+      attention_output: Multi-headed outputs of attention computation.
+      attention_scores: Multi-headed attention weights.
+    """
     # Take the dot product between "query" and "key" to get the raw
     # attention scores.
-    attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
+    attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
+                                 query_tensor)
     attention_scores = tf.multiply(attention_scores,
                                    1.0 / math.sqrt(float(self._key_size)))

-    # Apply talking heads before softmax.
-    attention_scores = tf.einsum("BNFT,NL->BLFT", attention_scores,
-                                 self._pre_softmax_weight)
+    # Apply linear projection before softmax.
+    attention_scores = tf.einsum(self._talking_heads_equation,
+                                 attention_scores, self._pre_softmax_weight)

     # Normalize the attention scores to probabilities.
-    # `attention_probs` = [B, N, F, T]
-    attention_probs = self._masked_softmax(attention_scores, attention_mask)
+    # `attention_scores` = [B, N, T, S]
+    attention_scores = self._masked_softmax(attention_scores, attention_mask)

-    # Apply talking heads after softmax.
-    attention_probs = tf.einsum("BNFT,NL->BLFT", attention_probs,
-                                self._post_softmax_weight)
+    # Apply linear projection after softmax.
+    attention_scores = tf.einsum(self._talking_heads_equation,
+                                 attention_scores, self._post_softmax_weight)

     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
-    attention_probs = self._dropout(attention_probs)
+    attention_scores_dropout = self._dropout_layer(attention_scores)

-    # `context_layer` = [B, F, N, H]
-    attention_output = tf.einsum("BNFT,BTNH->BFNH", attention_probs,
-                                 value_tensor)
-    attention_output = self._output_dense(attention_output)
-
-    return attention_output
+    # `context_layer` = [B, T, N, H]
+    attention_output = tf.einsum(self._combine_equation,
+                                 attention_scores_dropout, value_tensor)
+    return attention_output, attention_scores
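For context, the following toy sketch traces the same computation `_compute_attention` performs for the standard rank-4 case, spelled out with the explicit einsum equations the old `call` implementation used ("BTNH,BFNH->BNFT", "BNFT,NL->BLFT", "BNFT,BTNH->BFNH"). The random weights stand in for the layer's learned `_pre_softmax_weight` and `_post_softmax_weight`; this is an editor's illustration, not part of the commit:

import math

import tensorflow as tf

B, F, T, N, H = 2, 5, 7, 4, 8     # batch, query len, key len, heads, head size
query = tf.random.normal((B, F, N, H))
key = tf.random.normal((B, T, N, H))
value = tf.random.normal((B, T, N, H))
pre_w = tf.random.normal((N, N))   # stand-in for _pre_softmax_weight
post_w = tf.random.normal((N, N))  # stand-in for _post_softmax_weight

scores = tf.einsum("BTNH,BFNH->BNFT", key, query) / math.sqrt(float(H))
scores = tf.einsum("BNFT,NL->BLFT", scores, pre_w)    # talking heads, pre-softmax
probs = tf.nn.softmax(scores, axis=-1)                # softmax over key positions
probs = tf.einsum("BNFT,NL->BLFT", probs, post_w)     # talking heads, post-softmax
context = tf.einsum("BNFT,BTNH->BFNH", probs, value)  # combine with values
print(context.shape)  # (2, 5, 4, 8)

The only difference from standard multi-head attention is the two head-mixing einsums around the softmax; everything else (scaling, masking, dropout, output projection) is inherited from `MultiHeadAttention`.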
File: official/nlp/modeling/layers/talking_heads_attention_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
@@ -27,58 +28,97 @@ from official.nlp.modeling.layers import talking_heads_attention
 # This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
 # guarantees forward compatibility of this code for the V2 switchover.
+# This test is revised based on attention.MultiHeadAttentionTest.
 @keras_parameterized.run_all_keras_modes
-class MultiHeadAttentionTest(keras_parameterized.TestCase):
+class TalkingHeadsAttentionTest(keras_parameterized.TestCase):

-  def test_non_masked_attention(self):
+  @parameterized.named_parameters(
+      ("key_value_same_proj", None, None, [40, 80]),
+      ("key_value_different_proj", 32, 60, [40, 60]),
+  )
+  def test_non_masked_attention(self, value_size, output_shape, output_dims):
     """Test that the attention layer can be created without a mask tensor."""
     test_layer = talking_heads_attention.TalkingHeadsAttention(
-        num_heads=12, key_size=64)
+        num_heads=12,
+        key_size=64,
+        value_size=value_size,
+        output_shape=output_shape)
     # Create a 3-dimensional input (the first dimension is implicit).
-    from_tensor = tf.keras.Input(shape=(40, 80))
-    to_tensor = tf.keras.Input(shape=(20, 80))
-    output = test_layer([from_tensor, to_tensor])
-    self.assertEqual(output.shape.as_list(), [None, 40, 80])
+    query = tf.keras.Input(shape=(40, 80))
+    value = tf.keras.Input(shape=(20, 80))
+    output = test_layer([query, value])
+    self.assertEqual(output.shape.as_list(), [None] + output_dims)

   def test_non_masked_self_attention(self):
     """Test with one input (self-attention) and no mask tensor."""
     test_layer = talking_heads_attention.TalkingHeadsAttention(
         num_heads=12, key_size=64)
     # Create a 3-dimensional input (the first dimension is implicit).
-    from_tensor = tf.keras.Input(shape=(40, 80))
-    output = test_layer([from_tensor, from_tensor])
+    query = tf.keras.Input(shape=(40, 80))
+    output = test_layer([query, query])
     self.assertEqual(output.shape.as_list(), [None, 40, 80])

-  def test_masked_attention(self):
+  def test_attention_scores(self):
+    """Test attention outputs with coefficients."""
+    test_layer = talking_heads_attention.TalkingHeadsAttention(
+        num_heads=12, key_size=64, return_attention_scores=True)
+    # Create a 3-dimensional input (the first dimension is implicit).
+    query = tf.keras.Input(shape=(40, 80))
+    output, coef = test_layer([query, query])
+    self.assertEqual(output.shape.as_list(), [None, 40, 80])
+    self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
+
+  @parameterized.named_parameters(("with_bias", True), ("no_bias", False))
+  def test_masked_attention(self, use_bias):
     """Test with a mask tensor."""
     test_layer = talking_heads_attention.TalkingHeadsAttention(
-        num_heads=2, key_size=2)
+        num_heads=12, key_size=2, use_bias=use_bias)
     # Create a 3-dimensional input (the first dimension is implicit).
-    from_tensor = tf.keras.Input(shape=(4, 8))
-    to_tensor = tf.keras.Input(shape=(2, 8))
+    batch_size = 3
+    query = tf.keras.Input(shape=(4, 8))
+    value = tf.keras.Input(shape=(2, 8))
     mask_tensor = tf.keras.Input(shape=(4, 2))
-    output = test_layer([from_tensor, to_tensor], mask_tensor)
+    output = test_layer([query, value], mask_tensor)

     # Create a model containing the test layer.
-    model = tf.keras.Model([from_tensor, to_tensor, mask_tensor], output)
+    model = tf.keras.Model([query, value, mask_tensor], output)

     # Generate data for the input (non-mask) tensors.
-    from_data = 10 * np.random.random_sample((3, 4, 8))
-    to_data = 10 * np.random.random_sample((3, 2, 8))
+    from_data = 10 * np.random.random_sample((batch_size, 4, 8))
+    to_data = 10 * np.random.random_sample((batch_size, 2, 8))

     # Invoke the data with a random set of mask data. This should mask at least
     # one element.
-    mask_data = np.random.randint(2, size=(3, 4, 2))
+    mask_data = np.random.randint(2, size=(batch_size, 4, 2))
     masked_output_data = model.predict([from_data, to_data, mask_data])

     # Invoke the same data, but with a null mask (where no elements are masked).
-    null_mask_data = np.ones((3, 4, 2))
+    null_mask_data = np.ones((batch_size, 4, 2))
     unmasked_output_data = model.predict([from_data, to_data, null_mask_data])

     # Because one data is masked and one is not, the outputs should not be the
     # same.
     self.assertNotAllClose(masked_output_data, unmasked_output_data)

+    # Tests the layer with three inputs: Q, K, V.
+    key = tf.keras.Input(shape=(2, 8))
+    output = test_layer([query, value, key], mask_tensor)
+    model = tf.keras.Model([query, value, key, mask_tensor], output)
+
+    masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
+    unmasked_output_data = model.predict(
+        [from_data, to_data, to_data, null_mask_data])
+    # Because one data is masked and one is not, the outputs should not be the
+    # same.
+    self.assertNotAllClose(masked_output_data, unmasked_output_data)
+
+    if use_bias:
+      self.assertLen(test_layer._query_dense.trainable_variables, 2)
+      self.assertLen(test_layer._output_dense.trainable_variables, 2)
+    else:
+      self.assertLen(test_layer._query_dense.trainable_variables, 1)
+      self.assertLen(test_layer._output_dense.trainable_variables, 1)
+
   def test_initializer(self):
     """Test with a specified initializer."""
     test_layer = talking_heads_attention.TalkingHeadsAttention(
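The masked-attention test above feeds a random 0/1 mask of shape [batch, from_len, to_len]. In practice such a mask usually comes from padding; a small sketch of one way to build it from per-example key lengths (editor's illustration, assuming the `official` package is importable; `key_lengths` is a made-up input):

import tensorflow as tf

from official.nlp.modeling.layers import talking_heads_attention

layer = talking_heads_attention.TalkingHeadsAttention(num_heads=12, key_size=2)

batch_size, from_len, to_len, width = 3, 4, 2, 8
query = tf.random.normal((batch_size, from_len, width))
value = tf.random.normal((batch_size, to_len, width))

# Keys 0..length-1 are valid for each example; the rest are masked out.
key_lengths = tf.constant([2, 1, 2])
key_mask = tf.sequence_mask(key_lengths, maxlen=to_len)                  # [B, T]
attention_mask = tf.tile(key_mask[:, tf.newaxis, :], [1, from_len, 1])   # [B, F, T]

output = layer([query, value], attention_mask)
print(output.shape)  # (3, 4, 8)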
@@ -86,10 +126,38 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
         key_size=64,
         kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
     # Create a 3-dimensional input (the first dimension is implicit).
-    from_tensor = tf.keras.Input(shape=(40, 80))
-    output = test_layer([from_tensor, from_tensor])
+    query = tf.keras.Input(shape=(40, 80))
+    output = test_layer([query, query])
     self.assertEqual(output.shape.as_list(), [None, 40, 80])

+  @parameterized.named_parameters(
+      ("4d_inputs_one_free_batch", [3, 4], [3, 2], [4, 2], (2,)),
+      ("4D_inputs_2D_attention", [3, 4], [3, 2], [3, 4, 3, 2], (1, 2)),
+      ("5D_inputs_2D_attention", [5, 3, 4], [5, 3, 2], [3, 4, 3, 2], (2, 3)))
+  def test_high_dim_attention(self, q_dims, v_dims, mask_dims, attention_axes):
+    """Test with a mask tensor."""
+    test_layer = talking_heads_attention.TalkingHeadsAttention(
+        num_heads=12, key_size=2, attention_axes=attention_axes)
+    batch_size, hidden_size = 3, 8
+    # Generate data for the input (non-mask) tensors.
+    query_shape = [batch_size] + q_dims + [hidden_size]
+    value_shape = [batch_size] + v_dims + [hidden_size]
+    mask_shape = [batch_size] + mask_dims
+    query = 10 * np.random.random_sample(query_shape)
+    value = 10 * np.random.random_sample(value_shape)
+
+    # Invoke the data with a random set of mask data. This should mask at least
+    # one element.
+    mask_data = np.random.randint(2, size=mask_shape).astype("bool")
+    output = test_layer([query, value], mask_data)
+
+    # Invoke the same data, but with a null mask (where no elements are masked).
+    null_mask_data = np.ones(mask_shape)
+    unmasked_output = test_layer([query, value], null_mask_data)
+    # Because one data is masked and one is not, the outputs should not be the
+    # same.
+    self.assertNotAllClose(output, unmasked_output)
+
 if __name__ == "__main__":
   tf.test.main()
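The new `test_high_dim_attention` cases exercise the `attention_axes` argument inherited from `MultiHeadAttention`. A sketch of the "4d_inputs_one_free_batch" configuration as standalone usage (editor's illustration; shapes follow that test case, and the printed output shape is the expected one, not asserted by the test):

import numpy as np

from official.nlp.modeling.layers import talking_heads_attention

# Attend over axis 2 of 4-D inputs shaped (batch, free_batch_dim, seq, hidden).
layer = talking_heads_attention.TalkingHeadsAttention(
    num_heads=12, key_size=2, attention_axes=(2,))

query = np.random.random_sample((3, 3, 4, 8)).astype("float32")
value = np.random.random_sample((3, 3, 2, 8)).astype("float32")
mask = np.random.randint(2, size=(3, 4, 2)).astype("bool")   # [batch, q_attn, k_attn]

output = layer([query, value], mask)
print(output.shape)  # expected to follow the query shape: (3, 3, 4, 8)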