ModelZoo / ResNet50_tensorflow

Commit fffea332
Authored Sep 09, 2020 by Allen Wang, committed by A. Unique TensorFlower on Sep 09, 2020
Parent: cb6d8d6a

MultiHeadRelativeAttention compatibility changes with XLNet

PiperOrigin-RevId: 330751568

Changes: 1 changed file, 86 additions and 44 deletions

official/nlp/modeling/layers/attention.py  (+86 -44)
@@ -20,14 +20,29 @@ import string
 import tensorflow as tf
 
-from official.nlp.modeling.layers import masked_softmax
-
 EinsumDense = tf.keras.layers.experimental.EinsumDense
 MultiHeadAttention = tf.keras.layers.MultiHeadAttention
 _CHR_IDX = string.ascii_lowercase
 
 
+def _large_compatible_negative(tensor_type):
+  """Large negative number as Tensor.
+
+  This function is necessary because the standard value for epsilon
+  in this module (-1e9) cannot be represented using tf.float16
+
+  Args:
+    tensor_type: a dtype to determine the type.
+
+  Returns:
+    a large negative number.
+  """
+  if tensor_type == tf.float16:
+    return tf.float16.min
+  return -1e9
+
+
 @tf.keras.utils.register_keras_serializable(package="Text")
 class CachedAttention(tf.keras.layers.MultiHeadAttention):
   """Attention layer with cache used for auto-agressive decoding.
@@ -116,14 +131,15 @@ class CachedAttention(tf.keras.layers.MultiHeadAttention):
 
 def _rel_shift(x, klen=-1):
   """Performs relative shift to form the relative attention score."""
-  x = tf.transpose(x, perm=[1, 2, 0, 3])
+  x = tf.transpose(x, perm=[2, 3, 0, 1])
   x_size = tf.shape(x)
 
   x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]])
   x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
   x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]])
   x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1])
-  x = tf.transpose(x, perm=[2, 0, 1, 3])
+  x = tf.transpose(x, perm=[2, 3, 0, 1])
 
   return x
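The permutation change tracks a change in score layout: the old code built scores with its hard-coded "bind,bjnd->bijn" einsums (query length and key length before the head axis), while the new code reuses the parent layer's _dot_product_equation, which appears to yield a [batch, num_heads, query_len, relative_len] layout. A small, self-contained sketch of what the updated shift does under that assumed layout; the helper below simply copies the new _rel_shift:

import tensorflow as tf


def rel_shift(x, klen=-1):
  # Copy of the updated _rel_shift, for scores laid out as
  # [batch, num_heads, query_len, relative_len].
  x = tf.transpose(x, perm=[2, 3, 0, 1])
  x_size = tf.shape(x)
  x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]])
  x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
  x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]])
  x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1])
  x = tf.transpose(x, perm=[2, 3, 0, 1])
  return x


# Toy scores: batch=1, heads=1, 3 queries, 4 relative offsets.
scores = tf.reshape(tf.range(12, dtype=tf.float32), [1, 1, 3, 4])
shifted = rel_shift(scores, klen=3)
print(shifted.shape)  # (1, 1, 3, 3): each query row now indexes absolute keys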
@@ -200,15 +216,17 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
     to certain positions.
   """
 
+  def __init__(self, kernel_initializer="variance_scaling", **kwargs):
+    super().__init__(kernel_initializer=kernel_initializer, **kwargs)
+
   def _build_from_signature(self, query, value, key=None):
     super(MultiHeadRelativeAttention, self)._build_from_signature(
         query=query,
         value=value,
         key=key)
-    if hasattr(query, "shape"):
-      query_shape = tf.TensorShape(query.shape)
-    else:
-      query_shape = query
     if hasattr(value, "shape"):
       value_shape = tf.TensorShape(value.shape)
     else:
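The new constructor only changes the default kernel_initializer to "variance_scaling" (presumably to match the XLNet reference setup; the commit itself gives no rationale) and forwards everything else to tf.keras.layers.MultiHeadAttention. A minimal sketch of the same pattern on a hypothetical subclass, showing that callers can still override the default:

import tensorflow as tf


class RelativeAttentionLike(tf.keras.layers.MultiHeadAttention):
  # Hypothetical subclass: demonstrates only the constructor-default pattern,
  # not the relative-attention logic itself.

  def __init__(self, kernel_initializer="variance_scaling", **kwargs):
    super().__init__(kernel_initializer=kernel_initializer, **kwargs)


layer = RelativeAttentionLike(num_heads=4, key_dim=16)
print(layer.get_config()["kernel_initializer"]["class_name"])  # VarianceScaling

explicit = RelativeAttentionLike(
    num_heads=4, key_dim=16, kernel_initializer="glorot_uniform")
print(explicit.get_config()["kernel_initializer"]["class_name"])  # GlorotUniform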
@@ -230,36 +248,16 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
         bias_constraint=self._bias_constraint)
 
     with tf.init_scope():
-      free_dims = query_shape.rank - 1
-      einsum_equation, bias_axes, output_rank = _build_proj_equation(
+      einsum_equation, _, output_rank = _build_proj_equation(
           key_shape.rank - 1, bound_dims=1, output_dims=2)
       self._encoding_dense = EinsumDense(
           einsum_equation,
           output_shape=_get_output_shape(output_rank - 1,
                                          [self._num_heads, self._key_dim]),
-          bias_axes=bias_axes if self._use_bias else None,
+          bias_axes=None,
           name="encoding",
           **common_kwargs)
-      output_shape = [query_shape[-1]]
-      einsum_equation, bias_axes, output_rank = _build_proj_equation(
-          free_dims, bound_dims=2, output_dims=len(output_shape))
-      # TODO(allencwang) - replace all einsums with programmatic equations.
-      einsum_equation = "abcd,ecd->abe"
-      self._output_dense = EinsumDense(
-          einsum_equation,
-          output_shape=_get_output_shape(output_rank - 1, output_shape),
-          bias_axes=bias_axes if self._use_bias else None,
-          name="attention_output",
-          **common_kwargs)
-
-  def _build_attention(self, rank):
-    self._masked_softmax = masked_softmax.MaskedSoftmax(
-        mask_expansion_axes=[1], normalization_axes=[2])
-    self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
 
   def compute_attention(self,
                         query,
                         key,
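_encoding_dense projects the relative position encoding into per-head form. The exact equation comes from _build_proj_equation, which is defined elsewhere in this file and not shown in the diff; for a rank-3 [batch, length, hidden] input it should be equivalent to the hand-written "abc,cde->abde" used in the sketch below (an assumption, along with the sizes). bias_axes=None matches the new code above: the position-encoding projection carries no bias.

import tensorflow as tf

num_heads, key_dim, hidden = 4, 16, 64

# Stand-in for self._encoding_dense: project [B, L, hidden] relative position
# encodings to per-head [B, L, num_heads, key_dim].
encoding_dense = tf.keras.layers.experimental.EinsumDense(
    "abc,cde->abde",                       # assumed equivalent equation
    output_shape=[None, num_heads, key_dim],
    bias_axes=None,
    name="encoding")

relative_position_encoding = tf.random.normal([2, 9, hidden])
print(encoding_dense(relative_position_encoding).shape)  # (2, 9, 4, 16)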
@@ -267,6 +265,9 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
                         position,
                         content_attention_bias,
                         positional_attention_bias,
+                        segment_matrix=None,
+                        segment_encoding=None,
+                        segment_attention_bias=None,
                         attention_mask=None):
     """Computes the attention.
@@ -282,33 +283,59 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
         when calculating the content-based attention score.
       positional_attention_bias: Trainable bias parameter added to the query
         head when calculating the position-based attention score.
+      segment_matrix: Optional `Tensor` representing segmentation IDs used in
+        XLNet.
+      segment_encoding: Optional trainable `Tensor` representing the
+        segmentation encoding as used in XLNet.
+      segment_attention_bias: Optional trainable bias parameter added to the
+        query had when calculating the segment-based attention score used in
+        XLNet.
       attention_mask: (default None) Optional mask that is added to attention
         logits. If state is not None, the mask source sequence dimension should
         extend M.
 
     Returns:
       attention_output: Multi-headed output of attention computation of shape
-        `[B, T, N, key_dim]`.
+        `[B, S, N, key_dim]`.
     """
-    content_attention = tf.einsum("bind,bjnd->bijn",
-                                  query + content_attention_bias,
-                                  key)
-    positional_attention = tf.einsum("bind,bjnd->bijn",
-                                     query + positional_attention_bias,
-                                     position)
-    positional_attention = _rel_shift(
-        positional_attention, klen=tf.shape(content_attention)[2])
+    content_attention = tf.einsum(self._dot_product_equation,
+                                  key,
+                                  query + content_attention_bias)
+    positional_attention = tf.einsum(self._dot_product_equation,
+                                     position,
+                                     query + positional_attention_bias)
+    positional_attention = _rel_shift(
+        positional_attention, klen=tf.shape(content_attention)[3])
 
-    attention_scores = tf.multiply((content_attention + positional_attention),
-                                   1.0 / math.sqrt(float(self._key_dim)))
+    if segment_matrix is not None:
+      segment_attention = tf.einsum("bind,snd->bnis",
+                                    query + segment_attention_bias,
+                                    segment_encoding)
+      target_shape = tf.shape(positional_attention)
+      segment_attention = tf.where(
+          tf.broadcast_to(tf.expand_dims(segment_matrix, 1), target_shape),
+          tf.broadcast_to(segment_attention[:, :, :, 1:], target_shape),
+          tf.broadcast_to(segment_attention[:, :, :, :1], target_shape))
+      attention_sum = (
+          content_attention + positional_attention + segment_attention)
+    else:
+      attention_sum = content_attention + positional_attention
 
-    attention_scores = self._masked_softmax(attention_scores, attention_mask)
+    attention_scores = tf.multiply(
+        attention_sum, 1.0 / math.sqrt(float(self._key_dim)))
+
+    # `attention_scores`: `[B, N, S, S + M]`
+    if attention_mask is not None:
+      attention_scores += (
+          _large_compatible_negative(attention_scores.dtype) * attention_mask)
+
+    attention_scores = tf.nn.softmax(attention_scores, 3)
     attention_output = self._dropout_layer(attention_scores)
 
-    attention_output = tf.einsum("bijn,bjnd->bind", attention_output, value)
+    attention_output = tf.einsum(self._combine_equation,
+                                 attention_output,
+                                 value)
     return attention_output
 
   def call(self,
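The rewritten masking path no longer routes through a MaskedSoftmax layer: the scaled scores get a dtype-safe large negative added wherever the mask is set, then go through a plain tf.nn.softmax over the key axis. A small sketch of just that step (shapes and the key_dim value are made up; note that in this formulation a mask value of 1 marks a position to be suppressed):

import math

import tensorflow as tf


def large_compatible_negative(tensor_type):
  return tf.float16.min if tensor_type == tf.float16 else -1e9


key_dim = 16
# Toy scores shaped [B, N, S, S + M] = [1, 2, 3, 3]; the mask broadcasts over
# the head axis and blocks the last key for every query.
attention_sum = tf.random.normal([1, 2, 3, 3])
attention_mask = tf.constant([[[[0., 0., 1.],
                                [0., 0., 1.],
                                [0., 0., 1.]]]])

attention_scores = tf.multiply(attention_sum,
                               1.0 / math.sqrt(float(key_dim)))
attention_scores += (
    large_compatible_negative(attention_scores.dtype) * attention_mask)
attention_scores = tf.nn.softmax(attention_scores, 3)
print(attention_scores[0, 0, 0].numpy())  # last probability is ~0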
@@ -318,6 +345,9 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
            positional_attention_bias,
            key=None,
            relative_position_encoding=None,
+           segment_matrix=None,
+           segment_encoding=None,
+           segment_attention_bias=None,
            state=None,
            attention_mask=None):
     """Compute multi-head relative attention over inputs.
@@ -342,6 +372,13 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
       key: attention input.
       relative_position_encoding: relative positional encoding for key and
         value.
+      segment_matrix: Optional `Tensor` representing segmentation IDs used in
+        XLNet.
+      segment_encoding: Optional `Tensor` representing the segmentation
+        encoding as used in XLNet.
+      segment_attention_bias: Optional trainable bias parameter added to the
+        query had when calculating the segment-based attention score used in
+        XLNet.
       state: (default None) optional state. If passed, this is also attended
         over as in TransformerXL.
       attention_mask: (default None) Optional mask that is added to attention
@@ -381,7 +418,12 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
         position=position,
         content_attention_bias=content_attention_bias,
         positional_attention_bias=positional_attention_bias,
+        segment_matrix=segment_matrix,
+        segment_encoding=segment_encoding,
+        segment_attention_bias=segment_attention_bias,
         attention_mask=attention_mask)
+
+    # `attention_output` = [B, S, N, H]
     attention_output = self._output_dense(attention_output)
 
     return attention_output
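After compute_attention, the per-head output [B, S, N, H] goes through self._output_dense, which this commit no longer builds in _build_from_signature (it now comes from the base MultiHeadAttention). The sketch below reproduces the shape contract using the "abcd,ecd->abe" equation that the removed override hard-coded; the base layer's projection may be built from a programmatically generated but equivalent equation, and the sizes are made up:

import tensorflow as tf

num_heads, key_dim, hidden = 4, 16, 64

output_dense = tf.keras.layers.experimental.EinsumDense(
    "abcd,ecd->abe",          # [B, S, N, H] x [hidden, N, H] -> [B, S, hidden]
    output_shape=[None, hidden],
    bias_axes="e",
    name="attention_output")

attention_output = tf.random.normal([2, 5, num_heads, key_dim])  # [B, S, N, H]
print(output_dense(attention_output).shape)  # (2, 5, 64)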