Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
2e41d8ca
Commit
2e41d8ca
authored
Sep 14, 2020
by
Allen Wang
Committed by
A. Unique TensorFlower
Sep 14, 2020
Browse files
Move `MultiHeadRelativeAttention` to relative_attention.py.
PiperOrigin-RevId: 331652519
parent
1c5dba9e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
475 additions
and
354 deletions
+475
-354
official/nlp/modeling/layers/attention.py
official/nlp/modeling/layers/attention.py
+0
-321
official/nlp/modeling/layers/attention_test.py
official/nlp/modeling/layers/attention_test.py
+0
-33
official/nlp/modeling/layers/relative_attention.py
official/nlp/modeling/layers/relative_attention.py
+346
-0
official/nlp/modeling/layers/relative_attention_test.py
official/nlp/modeling/layers/relative_attention_test.py
+129
-0
No files found.
official/nlp/modeling/layers/attention.py
View file @
2e41d8ca
...
...
@@ -16,31 +16,11 @@
"""Keras-based attention layer."""
# pylint: disable=g-classes-have-attributes
import
math
import
string
import
tensorflow
as
tf
EinsumDense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
MultiHeadAttention
=
tf
.
keras
.
layers
.
MultiHeadAttention
_CHR_IDX
=
string
.
ascii_lowercase
def
_large_compatible_negative
(
tensor_type
):
"""Large negative number as Tensor.
This function is necessary because the standard value for epsilon
in this module (-1e9) cannot be represented using tf.float16
Args:
tensor_type: a dtype to determine the type.
Returns:
a large negative number.
"""
if
tensor_type
==
tf
.
float16
:
return
tf
.
float16
.
min
return
-
1e9
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
"Text"
)
...
...
@@ -126,304 +106,3 @@ class CachedAttention(tf.keras.layers.MultiHeadAttention):
if
return_attention_scores
:
return
attention_output
,
attention_scores
,
cache
return
attention_output
,
cache
def
_rel_shift
(
x
,
klen
=-
1
):
"""Performs relative shift to form the relative attention score."""
x
=
tf
.
transpose
(
x
,
perm
=
[
2
,
3
,
0
,
1
])
x_size
=
tf
.
shape
(
x
)
x
=
tf
.
reshape
(
x
,
[
x_size
[
1
],
x_size
[
0
],
x_size
[
2
],
x_size
[
3
]])
x
=
tf
.
slice
(
x
,
[
1
,
0
,
0
,
0
],
[
-
1
,
-
1
,
-
1
,
-
1
])
x
=
tf
.
reshape
(
x
,
[
x_size
[
0
],
x_size
[
1
]
-
1
,
x_size
[
2
],
x_size
[
3
]])
x
=
tf
.
slice
(
x
,
[
0
,
0
,
0
,
0
],
[
-
1
,
klen
,
-
1
,
-
1
])
x
=
tf
.
transpose
(
x
,
perm
=
[
2
,
3
,
0
,
1
])
return
x
def
_build_proj_equation
(
free_dims
,
bound_dims
,
output_dims
):
"""Builds an einsum equation for projections inside multi-head attention."""
input_str
=
""
kernel_str
=
""
output_str
=
""
bias_axes
=
""
letter_offset
=
0
for
i
in
range
(
free_dims
):
char
=
_CHR_IDX
[
i
+
letter_offset
]
input_str
+=
char
output_str
+=
char
letter_offset
+=
free_dims
for
i
in
range
(
bound_dims
):
char
=
_CHR_IDX
[
i
+
letter_offset
]
input_str
+=
char
kernel_str
+=
char
letter_offset
+=
bound_dims
for
i
in
range
(
output_dims
):
char
=
_CHR_IDX
[
i
+
letter_offset
]
kernel_str
+=
char
output_str
+=
char
bias_axes
+=
char
equation
=
"%s,%s->%s"
%
(
input_str
,
kernel_str
,
output_str
)
return
equation
,
bias_axes
,
len
(
output_str
)
def
_get_output_shape
(
output_rank
,
known_last_dims
):
return
[
None
]
*
(
output_rank
-
len
(
known_last_dims
))
+
list
(
known_last_dims
)
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
"Text"
)
class
MultiHeadRelativeAttention
(
MultiHeadAttention
):
"""A multi-head attention layer with relative attention + position encoding.
This layer shares the same input/output projections as the common
MultiHeadAttention layer.
When it calculates attention logits, position encoding is projected to form
relative keys. The logits are composed by shifted relative logits and content
logits.
**Note: This layer is currently experimental.
Arguments:
num_heads: The number of attention heads.
key_dim: Size of each attention head for query and key.
value_dim: Size of attention head for value.
dropout: Dropout probability for attention.
use_bias: Boolean, whether the dense layers use bias vectors/matrices.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
Call args:
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, S, dim]`.
content_attention_bias: Bias `Tensor` for content based attention of shape
`[num_heads, dim]`.
position_attention_bias: Bias `Tensor` for position based attention of shape
`[num_heads, dim]`.
relative_position_encoding: Relative positional encoding `Tensor` of shape
`[B, L, dim]`.
state: Optional `Tensor` of shape [B, M, E] where M is the length of the
state or memory.
If passed, this is also attended over as in Transformer XL.
key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
to certain positions.
"""
def
__init__
(
self
,
kernel_initializer
=
"variance_scaling"
,
**
kwargs
):
super
().
__init__
(
kernel_initializer
=
kernel_initializer
,
**
kwargs
)
def
_build_from_signature
(
self
,
query
,
value
,
key
=
None
):
super
(
MultiHeadRelativeAttention
,
self
).
_build_from_signature
(
query
=
query
,
value
=
value
,
key
=
key
)
if
hasattr
(
value
,
"shape"
):
value_shape
=
tf
.
TensorShape
(
value
.
shape
)
else
:
value_shape
=
value
if
key
is
None
:
key_shape
=
value_shape
elif
hasattr
(
key
,
"shape"
):
key_shape
=
tf
.
TensorShape
(
key
.
shape
)
else
:
key_shape
=
key
common_kwargs
=
dict
(
kernel_initializer
=
self
.
_kernel_initializer
,
bias_initializer
=
self
.
_bias_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activity_regularizer
=
self
.
_activity_regularizer
,
kernel_constraint
=
self
.
_kernel_constraint
,
bias_constraint
=
self
.
_bias_constraint
)
with
tf
.
init_scope
():
einsum_equation
,
_
,
output_rank
=
_build_proj_equation
(
key_shape
.
rank
-
1
,
bound_dims
=
1
,
output_dims
=
2
)
self
.
_encoding_dense
=
EinsumDense
(
einsum_equation
,
output_shape
=
_get_output_shape
(
output_rank
-
1
,
[
self
.
_num_heads
,
self
.
_key_dim
]),
bias_axes
=
None
,
name
=
"encoding"
,
**
common_kwargs
)
def
compute_attention
(
self
,
query
,
key
,
value
,
position
,
content_attention_bias
,
positional_attention_bias
,
segment_matrix
=
None
,
segment_encoding
=
None
,
segment_attention_bias
=
None
,
attention_mask
=
None
):
"""Computes the attention.
This function defines the computation inside `call` with projected
multihead Q, K, V, R inputs.
Args:
query: Projected query `Tensor` of shape `[B, T, N, key_dim]`.
key: Projected key `Tensor` of shape `[B, S + M, N, key_dim]`.
value: Projected value `Tensor` of shape `[B, S + M, N, key_dim]`.
position: Projected position `Tensor` of shape `[B, L, N, key_dim]`.
content_attention_bias: Trainable bias parameter added to the query head
when calculating the content-based attention score.
positional_attention_bias: Trainable bias parameter added to the query
head when calculating the position-based attention score.
segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet.
segment_encoding: Optional trainable `Tensor` representing the
segmentation encoding as used in XLNet.
segment_attention_bias: Optional trainable bias parameter added to the
query had when calculating the segment-based attention score used in
XLNet.
attention_mask: (default None) Optional mask that is added to attention
logits. If state is not None, the mask source sequence dimension should
extend M.
Returns:
attention_output: Multi-headed output of attention computation of shape
`[B, S, N, key_dim]`.
"""
content_attention
=
tf
.
einsum
(
self
.
_dot_product_equation
,
key
,
query
+
content_attention_bias
)
positional_attention
=
tf
.
einsum
(
self
.
_dot_product_equation
,
position
,
query
+
positional_attention_bias
)
positional_attention
=
_rel_shift
(
positional_attention
,
klen
=
tf
.
shape
(
content_attention
)[
3
])
if
segment_matrix
is
not
None
:
segment_attention
=
tf
.
einsum
(
"bind,snd->bnis"
,
query
+
segment_attention_bias
,
segment_encoding
)
target_shape
=
tf
.
shape
(
positional_attention
)
segment_attention
=
tf
.
where
(
tf
.
broadcast_to
(
tf
.
expand_dims
(
segment_matrix
,
1
),
target_shape
),
tf
.
broadcast_to
(
segment_attention
[:,
:,
:,
1
:],
target_shape
),
tf
.
broadcast_to
(
segment_attention
[:,
:,
:,
:
1
],
target_shape
))
attention_sum
=
(
content_attention
+
positional_attention
+
segment_attention
)
else
:
attention_sum
=
content_attention
+
positional_attention
attention_scores
=
tf
.
multiply
(
attention_sum
,
1.0
/
math
.
sqrt
(
float
(
self
.
_key_dim
)))
# `attention_scores`: `[B, N, S, S + M]`
if
attention_mask
is
not
None
:
attention_scores
+=
(
_large_compatible_negative
(
attention_scores
.
dtype
)
*
attention_mask
)
attention_scores
=
tf
.
nn
.
softmax
(
attention_scores
,
3
)
attention_output
=
self
.
_dropout_layer
(
attention_scores
)
attention_output
=
tf
.
einsum
(
self
.
_combine_equation
,
attention_output
,
value
)
return
attention_output
def
call
(
self
,
query
,
value
,
content_attention_bias
,
positional_attention_bias
,
key
=
None
,
relative_position_encoding
=
None
,
segment_matrix
=
None
,
segment_encoding
=
None
,
segment_attention_bias
=
None
,
state
=
None
,
attention_mask
=
None
):
"""Compute multi-head relative attention over inputs.
Size glossary:
* Number of heads (H): the number of attention heads.
* Value size (V): the size of each value embedding per head.
* Key size (K): the size of each key embedding per head. Equally, the size
of each query embedding per head. Typically K <= V.
* Batch dimensions (B).
* Query (target) attention axes shape (T).
* Value (source) attention axes shape (S), the rank must match the target.
* Encoding length (L): The relative positional encoding length.
Args:
query: attention input.
value: attention input.
content_attention_bias: A trainable bias parameter added to the query
head when calculating the content-based attention score.
positional_attention_bias: A trainable bias parameter added to the query
head when calculating the position-based attention score.
key: attention input.
relative_position_encoding: relative positional encoding for key and
value.
segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet.
segment_encoding: Optional `Tensor` representing the segmentation
encoding as used in XLNet.
segment_attention_bias: Optional trainable bias parameter added to the
query had when calculating the segment-based attention score used in
XLNet.
state: (default None) optional state. If passed, this is also attended
over as in TransformerXL.
attention_mask: (default None) Optional mask that is added to attention
logits. If state is not None, the mask source sequence dimension should
extend M.
Returns:
attention_output: The result of the computation, of shape [B, T, E],
where `T` is for target sequence shapes and `E` is the query input last
dimension if `output_shape` is `None`. Otherwise, the multi-head outputs
are projected to the shape specified by `output_shape`.
"""
if
not
self
.
_built_from_signature
:
self
.
_build_from_signature
(
query
,
value
,
key
=
key
)
if
key
is
None
:
key
=
value
if
state
is
not
None
and
state
.
shape
.
ndims
>
1
:
value
=
tf
.
concat
([
state
,
value
],
1
)
key
=
tf
.
concat
([
state
,
key
],
1
)
# `query` = [B, T, N ,H]
query
=
self
.
_query_dense
(
query
)
# `key` = [B, S + M, N, H]
key
=
self
.
_key_dense
(
key
)
# `value` = [B, S + M, N, H]
value
=
self
.
_value_dense
(
value
)
# `position` = [B, L, N, H]
position
=
self
.
_encoding_dense
(
relative_position_encoding
)
attention_output
=
self
.
compute_attention
(
query
=
query
,
key
=
key
,
value
=
value
,
position
=
position
,
content_attention_bias
=
content_attention_bias
,
positional_attention_bias
=
positional_attention_bias
,
segment_matrix
=
segment_matrix
,
segment_encoding
=
segment_encoding
,
segment_attention_bias
=
segment_attention_bias
,
attention_mask
=
attention_mask
)
# `attention_output` = [B, S, N, H]
attention_output
=
self
.
_output_dense
(
attention_output
)
return
attention_output
official/nlp/modeling/layers/attention_test.py
View file @
2e41d8ca
...
...
@@ -92,38 +92,5 @@ class CachedAttentionTest(keras_parameterized.TestCase):
self
.
assertEqual
(
cache
[
"value"
].
shape
,
(
3
,
4
,
2
,
2
))
@
keras_parameterized
.
run_all_keras_modes
class
MultiHeadRelativeAttentionTest
(
keras_parameterized
.
TestCase
):
def
test_attention_scores
(
self
):
num_heads
=
12
key_dim
=
64
value_dim
=
32
seq_length
=
8
batch_size
=
2
test_layer
=
attention
.
MultiHeadRelativeAttention
(
num_heads
=
num_heads
,
key_dim
=
key_dim
,
value_dim
=
value_dim
)
query
=
tf
.
random
.
normal
(
shape
=
(
batch_size
,
seq_length
,
key_dim
))
value
=
query
relative_position_encoding
=
tf
.
random
.
normal
(
shape
=
(
batch_size
,
seq_length
*
2
,
key_dim
))
content_attention_bias
=
tf
.
random
.
normal
(
shape
=
(
num_heads
,
key_dim
))
positional_attention_bias
=
tf
.
random
.
normal
(
shape
=
(
num_heads
,
key_dim
))
output
=
test_layer
(
query
=
query
,
value
=
value
,
content_attention_bias
=
content_attention_bias
,
positional_attention_bias
=
positional_attention_bias
,
relative_position_encoding
=
relative_position_encoding
,
state
=
None
,
attention_mask
=
None
)
self
.
assertEqual
(
output
.
shape
,
[
batch_size
,
seq_length
,
key_dim
])
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
official/nlp/modeling/layers/relative_attention.py
0 → 100644
View file @
2e41d8ca
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based relative attention layers."""
import
math
import
string
import
tensorflow
as
tf
_CHR_IDX
=
string
.
ascii_lowercase
def
_build_proj_equation
(
free_dims
,
bound_dims
,
output_dims
):
"""Builds an einsum equation for projections inside multi-head attention."""
input_str
=
""
kernel_str
=
""
output_str
=
""
bias_axes
=
""
letter_offset
=
0
for
i
in
range
(
free_dims
):
char
=
_CHR_IDX
[
i
+
letter_offset
]
input_str
+=
char
output_str
+=
char
letter_offset
+=
free_dims
for
i
in
range
(
bound_dims
):
char
=
_CHR_IDX
[
i
+
letter_offset
]
input_str
+=
char
kernel_str
+=
char
letter_offset
+=
bound_dims
for
i
in
range
(
output_dims
):
char
=
_CHR_IDX
[
i
+
letter_offset
]
kernel_str
+=
char
output_str
+=
char
bias_axes
+=
char
equation
=
"%s,%s->%s"
%
(
input_str
,
kernel_str
,
output_str
)
return
equation
,
bias_axes
,
len
(
output_str
)
def
_get_output_shape
(
output_rank
,
known_last_dims
):
return
[
None
]
*
(
output_rank
-
len
(
known_last_dims
))
+
list
(
known_last_dims
)
def
_large_compatible_negative
(
tensor_type
):
"""Large negative number as Tensor.
This function is necessary because the standard value for epsilon
in this module (-1e9) cannot be represented using tf.float16
Args:
tensor_type: a dtype to determine the type.
Returns:
a large negative number.
"""
if
tensor_type
==
tf
.
float16
:
return
tf
.
float16
.
min
return
-
1e9
def
_rel_shift
(
x
,
klen
=-
1
):
"""Performs relative shift to form the relative attention score."""
x
=
tf
.
transpose
(
x
,
perm
=
[
2
,
3
,
0
,
1
])
x_size
=
tf
.
shape
(
x
)
x
=
tf
.
reshape
(
x
,
[
x_size
[
1
],
x_size
[
0
],
x_size
[
2
],
x_size
[
3
]])
x
=
tf
.
slice
(
x
,
[
1
,
0
,
0
,
0
],
[
-
1
,
-
1
,
-
1
,
-
1
])
x
=
tf
.
reshape
(
x
,
[
x_size
[
0
],
x_size
[
1
]
-
1
,
x_size
[
2
],
x_size
[
3
]])
x
=
tf
.
slice
(
x
,
[
0
,
0
,
0
,
0
],
[
-
1
,
klen
,
-
1
,
-
1
])
x
=
tf
.
transpose
(
x
,
perm
=
[
2
,
3
,
0
,
1
])
return
x
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
"Text"
)
class
MultiHeadRelativeAttention
(
tf
.
keras
.
layers
.
MultiHeadAttention
):
"""A multi-head attention layer with relative attention + position encoding.
This layer shares the same input/output projections as the common
MultiHeadAttention layer.
When it calculates attention logits, position encoding is projected to form
relative keys. The logits are composed by shifted relative logits and content
logits.
**Note: This layer is currently experimental.
Attributes:
num_heads: The number of attention heads.
key_dim: Size of each attention head for query and key.
value_dim: Size of attention head for value.
dropout: Dropout probability for attention.
use_bias: Boolean, whether the dense layers use bias vectors/matrices.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
Call args:
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, S, dim]`.
content_attention_bias: Bias `Tensor` for content based attention of shape
`[num_heads, dim]`.
position_attention_bias: Bias `Tensor` for position based attention of shape
`[num_heads, dim]`.
key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
relative_position_encoding: Relative positional encoding `Tensor` of shape
`[B, L, dim]`.
segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet of shape `[B, S, S + M]`.
segment_encoding: Optional `Tensor` representing the segmentation
encoding as used in XLNet of shape `[2, num_heads, dim]`.
segment_attention_bias: Optional trainable bias parameter added to the
query had when calculating the segment-based attention score used in
XLNet of shape `[num_heads, dim]`.
state: Optional `Tensor` of shape [B, M, E] where M is the length of the
state or memory.
If passed, this is also attended over as in Transformer XL.
attention_mask: a boolean mask of shape `[B, T, S]` that prevents attention
to certain positions.
"""
def
__init__
(
self
,
kernel_initializer
=
"variance_scaling"
,
**
kwargs
):
super
().
__init__
(
kernel_initializer
=
kernel_initializer
,
**
kwargs
)
def
_build_from_signature
(
self
,
query
,
value
,
key
=
None
):
super
(
MultiHeadRelativeAttention
,
self
).
_build_from_signature
(
query
=
query
,
value
=
value
,
key
=
key
)
if
hasattr
(
value
,
"shape"
):
value_shape
=
tf
.
TensorShape
(
value
.
shape
)
else
:
value_shape
=
value
if
key
is
None
:
key_shape
=
value_shape
elif
hasattr
(
key
,
"shape"
):
key_shape
=
tf
.
TensorShape
(
key
.
shape
)
else
:
key_shape
=
key
common_kwargs
=
dict
(
kernel_initializer
=
self
.
_kernel_initializer
,
bias_initializer
=
self
.
_bias_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activity_regularizer
=
self
.
_activity_regularizer
,
kernel_constraint
=
self
.
_kernel_constraint
,
bias_constraint
=
self
.
_bias_constraint
)
with
tf
.
init_scope
():
einsum_equation
,
_
,
output_rank
=
_build_proj_equation
(
key_shape
.
rank
-
1
,
bound_dims
=
1
,
output_dims
=
2
)
self
.
_encoding_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
einsum_equation
,
output_shape
=
_get_output_shape
(
output_rank
-
1
,
[
self
.
_num_heads
,
self
.
_key_dim
]),
bias_axes
=
None
,
name
=
"encoding"
,
**
common_kwargs
)
def
compute_attention
(
self
,
query
,
key
,
value
,
position
,
content_attention_bias
,
positional_attention_bias
,
segment_matrix
=
None
,
segment_encoding
=
None
,
segment_attention_bias
=
None
,
attention_mask
=
None
):
"""Computes the attention.
This function defines the computation inside `call` with projected
multihead Q, K, V, R inputs.
Args:
query: Projected query `Tensor` of shape `[B, T, N, key_dim]`.
key: Projected key `Tensor` of shape `[B, S + M, N, key_dim]`.
value: Projected value `Tensor` of shape `[B, S + M, N, key_dim]`.
position: Projected position `Tensor` of shape `[B, L, N, key_dim]`.
content_attention_bias: Trainable bias parameter added to the query head
when calculating the content-based attention score.
positional_attention_bias: Trainable bias parameter added to the query
head when calculating the position-based attention score.
segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet.
segment_encoding: Optional trainable `Tensor` representing the
segmentation encoding as used in XLNet.
segment_attention_bias: Optional trainable bias parameter added to the
query had when calculating the segment-based attention score used in
XLNet.
attention_mask: (default None) Optional mask that is added to attention
logits. If state is not None, the mask source sequence dimension should
extend M.
Returns:
attention_output: Multi-headed output of attention computation of shape
`[B, S, N, key_dim]`.
"""
content_attention
=
tf
.
einsum
(
self
.
_dot_product_equation
,
key
,
query
+
content_attention_bias
)
positional_attention
=
tf
.
einsum
(
self
.
_dot_product_equation
,
position
,
query
+
positional_attention_bias
)
positional_attention
=
_rel_shift
(
positional_attention
,
klen
=
tf
.
shape
(
content_attention
)[
3
])
if
segment_matrix
is
not
None
:
segment_attention
=
tf
.
einsum
(
"bind,snd->bnis"
,
query
+
segment_attention_bias
,
segment_encoding
)
target_shape
=
tf
.
shape
(
positional_attention
)
segment_attention
=
tf
.
where
(
tf
.
broadcast_to
(
tf
.
expand_dims
(
segment_matrix
,
1
),
target_shape
),
tf
.
broadcast_to
(
segment_attention
[:,
:,
:,
1
:],
target_shape
),
tf
.
broadcast_to
(
segment_attention
[:,
:,
:,
:
1
],
target_shape
))
attention_sum
=
(
content_attention
+
positional_attention
+
segment_attention
)
else
:
attention_sum
=
content_attention
+
positional_attention
attention_scores
=
tf
.
multiply
(
attention_sum
,
1.0
/
math
.
sqrt
(
float
(
self
.
_key_dim
)))
# `attention_scores`: `[B, N, S, S + M]`
if
attention_mask
is
not
None
:
attention_scores
+=
(
_large_compatible_negative
(
attention_scores
.
dtype
)
*
attention_mask
)
attention_scores
=
tf
.
nn
.
softmax
(
attention_scores
,
3
)
attention_output
=
self
.
_dropout_layer
(
attention_scores
)
attention_output
=
tf
.
einsum
(
self
.
_combine_equation
,
attention_output
,
value
)
return
attention_output
def
call
(
self
,
query
,
value
,
content_attention_bias
,
positional_attention_bias
,
key
=
None
,
relative_position_encoding
=
None
,
segment_matrix
=
None
,
segment_encoding
=
None
,
segment_attention_bias
=
None
,
state
=
None
,
attention_mask
=
None
):
"""Compute multi-head relative attention over inputs.
Size glossary:
* Number of heads (H): the number of attention heads.
* Value size (V): the size of each value embedding per head.
* Key size (K): the size of each key embedding per head. Equally, the size
of each query embedding per head. Typically K <= V.
* Batch dimensions (B).
* Query (target) attention axes shape (T).
* Value (source) attention axes shape (S), the rank must match the target.
* Encoding length (L): The relative positional encoding length.
Args:
query: attention input.
value: attention input.
content_attention_bias: A trainable bias parameter added to the query
head when calculating the content-based attention score.
positional_attention_bias: A trainable bias parameter added to the query
head when calculating the position-based attention score.
key: attention input.
relative_position_encoding: relative positional encoding for key and
value.
segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet.
segment_encoding: Optional `Tensor` representing the segmentation
encoding as used in XLNet.
segment_attention_bias: Optional trainable bias parameter added to the
query had when calculating the segment-based attention score used in
XLNet.
state: (default None) optional state. If passed, this is also attended
over as in TransformerXL.
attention_mask: (default None) Optional mask that is added to attention
logits. If state is not None, the mask source sequence dimension should
extend M.
Returns:
attention_output: The result of the computation, of shape [B, T, E],
where `T` is for target sequence shapes and `E` is the query input last
dimension if `output_shape` is `None`. Otherwise, the multi-head outputs
are projected to the shape specified by `output_shape`.
"""
if
not
self
.
_built_from_signature
:
self
.
_build_from_signature
(
query
,
value
,
key
=
key
)
if
key
is
None
:
key
=
value
if
state
is
not
None
and
state
.
shape
.
ndims
>
1
:
value
=
tf
.
concat
([
state
,
value
],
1
)
key
=
tf
.
concat
([
state
,
key
],
1
)
# `query` = [B, T, N ,H]
query
=
self
.
_query_dense
(
query
)
# `key` = [B, S + M, N, H]
key
=
self
.
_key_dense
(
key
)
# `value` = [B, S + M, N, H]
value
=
self
.
_value_dense
(
value
)
# `position` = [B, L, N, H]
position
=
self
.
_encoding_dense
(
relative_position_encoding
)
attention_output
=
self
.
compute_attention
(
query
=
query
,
key
=
key
,
value
=
value
,
position
=
position
,
content_attention_bias
=
content_attention_bias
,
positional_attention_bias
=
positional_attention_bias
,
segment_matrix
=
segment_matrix
,
segment_encoding
=
segment_encoding
,
segment_attention_bias
=
segment_attention_bias
,
attention_mask
=
attention_mask
)
# `attention_output` = [B, S, N, H]
attention_output
=
self
.
_output_dense
(
attention_output
)
return
attention_output
official/nlp/modeling/layers/relative_attention_test.py
0 → 100644
View file @
2e41d8ca
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the attention layer."""
import
numpy
as
np
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.keras
import
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
official.nlp.modeling.layers
import
relative_attention
def
_create_mock_attention_data
(
num_heads
,
key_dim
,
value_dim
,
seq_length
,
batch_size
,
memory_length
=
0
,
include_state
=
False
,
include_mask
=
False
,
include_segment
=
False
):
"""Creates mock testing data.
Args:
num_heads: `int`, Number of attention heads.
key_dim: `int`, Size of query head.
value_dim: `int`, Size of key, value dim.
seq_length: `int`, Sequence length of the input.
batch_size: `int`, the batch size.
memory_length: optional `int`, the length of the state. Defaults to 0.
include_state: optional `bool`, whether or not to include state data.
include_mask: optional `bool`, whether or not to include mask data.
include_segment: optional `bool`, whether or not to include segment data.
Returns:
A dictionary with `str` as keys and `Tensor` as values.
"""
query_shape
=
(
batch_size
,
seq_length
,
key_dim
)
value_shape
=
(
batch_size
,
seq_length
,
value_dim
)
encoding_shape
=
(
batch_size
,
seq_length
*
2
,
key_dim
)
attention_bias_shape
=
(
num_heads
,
key_dim
)
data
=
dict
(
query
=
tf
.
random
.
normal
(
shape
=
query_shape
),
value
=
tf
.
random
.
normal
(
shape
=
value_shape
),
key
=
tf
.
random
.
normal
(
shape
=
value_shape
),
relative_position_encoding
=
tf
.
random
.
normal
(
shape
=
encoding_shape
),
content_attention_bias
=
tf
.
random
.
normal
(
shape
=
attention_bias_shape
),
positional_attention_bias
=
tf
.
random
.
normal
(
shape
=
attention_bias_shape
))
if
include_state
:
total_seq_length
=
seq_length
+
memory_length
state_data
=
dict
(
state
=
tf
.
random
.
normal
(
shape
=
(
batch_size
,
memory_length
,
value_dim
)))
data
.
update
(
state_data
)
else
:
total_seq_length
=
seq_length
if
include_mask
:
mask_shape
=
(
batch_size
,
num_heads
,
seq_length
,
total_seq_length
)
mask_data
=
dict
(
attention_mask
=
np
.
random
.
randint
(
2
,
size
=
mask_shape
).
astype
(
"float32"
))
data
.
update
(
mask_data
)
if
include_segment
:
segment_encoding_shape
=
(
2
,
num_heads
,
key_dim
)
segment_matrix
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
seq_length
,
total_seq_length
))
segment_matrix
=
tf
.
math
.
equal
(
segment_matrix
,
1
)
segment_data
=
dict
(
segment_attention_bias
=
tf
.
random
.
normal
(
shape
=
attention_bias_shape
),
segment_encoding
=
tf
.
random
.
normal
(
shape
=
segment_encoding_shape
),
segment_matrix
=
segment_matrix
)
data
.
update
(
segment_data
)
return
data
@
keras_parameterized
.
run_all_keras_modes
class
MultiHeadRelativeAttentionTest
(
keras_parameterized
.
TestCase
):
@
combinations
.
generate
(
combinations
.
combine
(
value_dim
=
[
32
,
64
],
memory_length
=
[
0
,
4
],
state
=
[
True
,
False
],
mask
=
[
True
,
False
],
segment
=
[
True
,
False
]))
def
test_attention_scores
(
self
,
value_dim
,
memory_length
,
state
,
mask
,
segment
):
"""Tests combinations of attention score calculations."""
batch_size
,
num_heads
,
key_dim
,
seq_length
=
2
,
12
,
64
,
8
test_layer
=
relative_attention
.
MultiHeadRelativeAttention
(
num_heads
=
num_heads
,
key_dim
=
key_dim
,
value_dim
=
value_dim
)
data
=
_create_mock_attention_data
(
num_heads
=
num_heads
,
key_dim
=
key_dim
,
value_dim
=
value_dim
,
seq_length
=
seq_length
,
memory_length
=
memory_length
,
batch_size
=
batch_size
,
include_state
=
state
,
include_mask
=
mask
,
include_segment
=
segment
)
output
=
test_layer
(
**
data
)
self
.
assertEqual
(
output
.
shape
,
[
batch_size
,
seq_length
,
key_dim
])
if
__name__
==
"__main__"
:
np
.
random
.
seed
(
0
)
tf
.
random
.
set_seed
(
0
)
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment