ModelZoo / ResNet50_tensorflow · Commits

Commit 09c5ae2f, authored May 09, 2020 by Hongkun Yu, committed by A. Unique TensorFlower on May 09, 2020.

Internal change

PiperOrigin-RevId: 310767440
Parent: 52e4ded8

Showing 4 changed files with 177 additions and 75 deletions:

    official/nlp/modeling/layers/attention.py        +161  -56
    official/nlp/modeling/layers/attention_test.py    +11   -2
    official/nlp/modeling/layers/transformer.py         +0   -4
    official/nlp/nhnet/decoder.py                        +5  -13
official/nlp/modeling/layers/attention.py (view file @ 09c5ae2f)
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -19,12 +20,98 @@ from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import collections
import math
import string

import numpy as np
import tensorflow as tf

from official.nlp.modeling.layers import dense_einsum
from official.nlp.modeling.layers import masked_softmax

EinsumDense = tf.keras.layers.experimental.EinsumDense
_CHR_IDX = string.ascii_lowercase


def _build_attention_equation(qkv_rank, attn_axes):
  """Builds einsum equations for the attention computation.

  Query, key, value inputs after projection are expected to have the shape as:
  (bs, <non-attention dims>, <attention dims>, num_heads, channels).
  bs and <non-attention dims> are treated as <batch dims>.

  The attention operations can be generalized:
  (1) Query-key dot product:
  (<batch dims>, <query attention dims>, num_heads, channels), (<batch dims>,
  <key attention dims>, num_heads, channels) -> (<batch dims>,
  num_heads, <query attention dims>, <key attention dims>)
  (2) Combination:
  (<batch dims>, num_heads, <query attention dims>, <key attention dims>),
  (<batch dims>, <value attention dims>, num_heads, channels) -> (<batch dims>,
  <query attention dims>, num_heads, channels)

  Args:
    qkv_rank: the rank of query, key, value tensors.
    attn_axes: a list/tuple of axes, [1, rank), that will do attention.

  Returns:
    Einsum equations.
  """
  target_notation = _CHR_IDX[:qkv_rank]
  # `batch_dims` includes the head dim.
  batch_dims = tuple(np.delete(range(qkv_rank), attn_axes + (qkv_rank - 1,)))
  letter_offset = qkv_rank
  source_notation = ""
  for i in range(qkv_rank):
    if i in batch_dims or i == qkv_rank - 1:
      source_notation += target_notation[i]
    else:
      source_notation += _CHR_IDX[letter_offset]
      letter_offset += 1

  product_notation = "".join([target_notation[i] for i in batch_dims] +
                             [target_notation[i] for i in attn_axes] +
                             [source_notation[i] for i in attn_axes])
  dot_product_equation = "%s,%s->%s" % (source_notation, target_notation,
                                        product_notation)
  combine_equation = "%s,%s->%s" % (product_notation, source_notation,
                                    target_notation)
  return dot_product_equation, combine_equation
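# Illustration (not part of this commit): for the rank-4 self-attention case
# wired up in `MultiHeadAttention.build()` below (qkv_rank=4, attn_axes=(1,)),
# this helper reproduces the contractions that were previously hard-coded:
#
#   dot_eq, combine_eq = _build_attention_equation(qkv_rank=4, attn_axes=(1,))
#   # dot_eq     == "aecd,abcd->acbe"  # (key, query) -> scores, as "BSNH,BTNH->BNTS"
#   # combine_eq == "acbe,aecd->abcd"  # (scores, value) -> output, as "BNTS,BSNH->BTNH"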
def _build_proj_equation(free_dims, bound_dims, output_dims):
  """Builds an einsum equation for projections inside multi-head attention."""
  input_str = ""
  kernel_str = ""
  output_str = ""
  bias_axes = ""
  letter_offset = 0
  for i in range(free_dims):
    char = _CHR_IDX[i + letter_offset]
    input_str += char
    output_str += char

  letter_offset += free_dims
  for i in range(bound_dims):
    char = _CHR_IDX[i + letter_offset]
    input_str += char
    kernel_str += char

  letter_offset += bound_dims
  for i in range(output_dims):
    char = _CHR_IDX[i + letter_offset]
    kernel_str += char
    output_str += char
    bias_axes += char
  equation = "%s,%s->%s" % (input_str, kernel_str, output_str)
  # The output rank does not consider the batch dimension.
  output_rank = len(output_str) - 1
  return equation, bias_axes, output_rank


def _get_output_shape(output_rank, known_last_dims):
  return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)
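# Illustration (not part of this commit): with a rank-3 query of shape
# (batch, seq, dim), `MultiHeadAttention.build()` below calls these helpers
# with free_dims=2, bound_dims=1, output_dims=2 for the query/key/value
# projections, giving:
#
#   eq, bias_axes, out_rank = _build_proj_equation(2, bound_dims=1, output_dims=2)
#   # eq == "abc,cde->abde", bias_axes == "de", out_rank == 3
#   _get_output_shape(out_rank, [num_heads, key_size])
#   # -> [None, num_heads, key_size], i.e. (seq, heads, head size) per example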
@tf.keras.utils.register_keras_serializable(package="Text")
class MultiHeadAttention(tf.keras.layers.Layer):
...
...
@@ -53,7 +140,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
    key_size: Size of each attention head for query and key.
    value_size: Size of each attention head for value.
    dropout: Dropout probability.
    use_bias: Boolean, whether the dense layers use bias vectors.
    use_bias: Boolean, whether the dense layers use bias vectors/matrices.
    output_shape: The expected shape of an output tensor, besides the batch and
      sequence dims. If not specified, projects back to the key feature dim.
    kernel_initializer: Initializer for dense layer kernels.
...
...
@@ -94,44 +181,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf.keras.constraints.get(bias_constraint)

    self._query_dense = dense_einsum.DenseEinsum(
        output_shape=(self._num_heads, self._key_size),
        use_bias=self._use_bias,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="query")

    self._key_dense = dense_einsum.DenseEinsum(
        output_shape=(self._num_heads, self._key_size),
        use_bias=self._use_bias,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="key")

    self._value_dense = dense_einsum.DenseEinsum(
        output_shape=(self._num_heads, self._value_size),
        use_bias=self._use_bias,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="value")

    self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[1])
    self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)

  def get_config(self):
...
...
@@ -167,22 +217,72 @@ class MultiHeadAttention(tf.keras.layers.Layer):
    return dict(list(base_config.items()) + list(config.items()))

  def build(self, input_shape):
    if self._output_shape:
      output_shape = self._output_shape
    else:
      input_shape = tf.TensorShape(input_shape[0])
      output_shape = input_shape[-1]
    self._output_dense = dense_einsum.DenseEinsum(
        output_shape=output_shape,
        num_summed_dimensions=2,
    inputs_len = len(input_shape)
    if inputs_len > 3 or inputs_len < 2:
      raise ValueError(
          "Expects inputs list of length 2 or 3, namely [query, value] or "
          "[query, value, key]. "
          "Given length: %d" % inputs_len)
    tensor_shapes = tf.nest.map_structure(tf.TensorShape, input_shape)
    query_shape = tensor_shapes[0]
    value_shape = tensor_shapes[1]
    key_shape = tensor_shapes[2] if inputs_len == 3 else value_shape
    common_kwargs = dict(
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="attention_output")
        bias_constraint=self._bias_constraint)
    free_dims = query_shape.rank - 1
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        free_dims, bound_dims=1, output_dims=2)
    self._query_dense = EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank,
                                       [self._num_heads, self._key_size]),
        bias_axes=bias_axes if self._use_bias else None,
        name="query",
        **common_kwargs)
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        key_shape.rank - 1, bound_dims=1, output_dims=2)
    self._key_dense = EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank,
                                       [self._num_heads, self._key_size]),
        bias_axes=bias_axes if self._use_bias else None,
        name="key",
        **common_kwargs)
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        value_shape.rank - 1, bound_dims=1, output_dims=2)
    self._value_dense = EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank,
                                       [self._num_heads, self._value_size]),
        bias_axes=bias_axes if self._use_bias else None,
        name="value",
        **common_kwargs)
    self._dot_product_equation, self._combine_equation = (
        _build_attention_equation(output_rank + 1, attn_axes=(1,)))
    if self._output_shape:
      if not isinstance(self._output_shape, collections.abc.Sized):
        output_shape = [self._output_shape]
      else:
        output_shape = self._output_shape
    else:
      output_shape = [query_shape[-1]]
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        free_dims, bound_dims=2, output_dims=len(output_shape))
    self._output_dense = EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank, output_shape),
        bias_axes=bias_axes if self._use_bias else None,
        name="attention_output",
        **common_kwargs)
    super(MultiHeadAttention, self).build(input_shape)
  def call(self, inputs, attention_mask=None):
...
...
@@ -234,7 +334,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    attention_scores = tf.einsum("BSNH,BTNH->BNTS", key_tensor, query_tensor)
    attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
                                 query_tensor)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(self._key_size)))
...
...
@@ -247,7 +348,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
    attention_probs = self._dropout(attention_probs)

    # `context_layer` = [B, T, N, H]
    attention_output = tf.einsum("BNTS,BSNH->BTNH", attention_probs,
    attention_output = tf.einsum(self._combine_equation, attention_probs,
                                 value_tensor)
    attention_output = self._output_dense(attention_output)
...
...
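A minimal usage sketch (not part of this commit), based on the constructor arguments documented above and the [query, value] input convention enforced in build(); the layer sizes and tensor shapes here are illustrative:

    import tensorflow as tf
    from official.nlp.modeling.layers import attention

    layer = attention.MultiHeadAttention(num_heads=8, key_size=64)
    query = tf.zeros((2, 16, 512))  # (batch, target_len, feature_dim)
    value = tf.zeros((2, 32, 512))  # (batch, source_len, feature_dim)
    output = layer([query, value])  # (2, 16, 512); with no output_shape given,
                                    # the layer projects back to query_shape[-1]

An optional attention mask can be passed as the second call argument, e.g. layer([query, value], attention_mask).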
@@ -288,11 +389,14 @@ class CachedAttention(MultiHeadAttention):
    return key_tensor, value_tensor

  def call(self, inputs, decode_loop_step=None):
  def call(self,
           inputs,
           attention_mask=None,
           cache=None,
           decode_loop_step=None):
    from_tensor = inputs[0]
    to_tensor = inputs[1]
    attention_mask = inputs[2] if len(inputs) >= 3 else None
    cache = inputs[3] if len(inputs) >= 4 else None

    # Scalar dimensions referenced here:
    # B = batch size (number of sequences)
    # F = `from_tensor` sequence length
...
...
@@ -314,7 +418,8 @@ class CachedAttention(MultiHeadAttention):
    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
    attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
                                 query_tensor)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(self._key_size)))
...
...
@@ -326,7 +431,7 @@ class CachedAttention(MultiHeadAttention):
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = self._dropout(attention_probs)

    # `context_layer` = [B, F, N, H]
    attention_output = tf.einsum("BNFT,BTNH->BFNH", attention_probs,
    attention_output = tf.einsum(self._combine_equation, attention_probs,
                                 value_tensor)
    attention_output = self._output_dense(attention_output)
    return attention_output, cache
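With this change, CachedAttention takes the attention mask and cache as call arguments rather than extra entries in the inputs list. A sketch of the new calling convention (not part of the diff), mirroring the updated tests below; `layer` is assumed to be a CachedAttention instance set up as in those tests, `from_data` and `mask_data` are the query/value and mask tensors from those tests, and the cache layout (batch, seq, num_heads, head_size) follows the shapes they assert on:

    batch, seq_len, num_heads, head_size = 3, 4, 2, 2
    cache = {
        "key": tf.zeros((batch, seq_len, num_heads, head_size)),
        "value": tf.zeros((batch, seq_len, num_heads, head_size)),
    }
    # Previously: layer([from_data, from_data, mask_data, cache])
    output, updated_cache = layer([from_data, from_data],  # [query, value]
                                  mask_data,               # attention_mask
                                  cache,
                                  decode_loop_step=None)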
official/nlp/modeling/layers/attention_test.py (view file @ 09c5ae2f)
...
...
@@ -99,6 +99,13 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
    # same.
    self.assertNotAllClose(masked_output_data, unmasked_output_data)

    if use_bias:
      self.assertLen(test_layer._query_dense.trainable_variables, 2)
      self.assertLen(test_layer._output_dense.trainable_variables, 2)
    else:
      self.assertLen(test_layer._query_dense.trainable_variables, 1)
      self.assertLen(test_layer._output_dense.trainable_variables, 1)

  def test_initializer(self):
    """Test with a specified initializer."""
    test_layer = attention.MultiHeadAttention(
...
...
@@ -143,7 +150,7 @@ class CachedAttentionTest(keras_parameterized.TestCase):
    # one element.
    mask_data = np.random.randint(
        2, size=(batch_size, from_seq_length, from_seq_length))
    masked_output_data, cache = layer([from_data, from_data, mask_data, cache])
    masked_output_data, cache = layer([from_data, from_data], mask_data, cache)
    self.assertEqual(masked_output_data.shape, (3, 4, 8))
    self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
...
...
@@ -170,7 +177,9 @@ class CachedAttentionTest(keras_parameterized.TestCase):
    mask_data = np.random.randint(
        2,
        size=(batch_size, from_seq_length, from_seq_length),
        dtype=np.int32)
    # Testing the invocation directly as Keras cannot consume inputs correctly.
    masked_output_data, cache = layer([from_data, from_data, mask_data, cache],
    masked_output_data, cache = layer([from_data, from_data],
                                      mask_data,
                                      cache,
                                      decode_loop_step=decode_loop_step)
    self.assertEqual(masked_output_data.shape, (3, 4, 8))
    self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
...
...
official/nlp/modeling/layers/transformer.py (view file @ 09c5ae2f)
...
...
@@ -116,10 +116,6 @@ class Transformer(tf.keras.layers.Layer):
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="self_attention")
    # TODO(hongkuny): Remove when checkpoint backward compatibility is resolved.
    # pylint: disable=protected-access
    self._attention_layer.build([input_tensor_shape])
    self._attention_output_dense = self._attention_layer._output_dense
    self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
    # Use float32 in layernorm for numeric stability.
...
...
official/nlp/nhnet/decoder.py (view file @ 09c5ae2f)
...
...
@@ -95,12 +95,6 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
        output_shape=self.hidden_size,
        kernel_initializer=self._kernel_initializer,
        name="attention/encdec")
    # TODO(hongkuny): Remove when checkpoint backward compatibility is resolved.
    # pylint: disable=protected-access
    self.self_attention.build(input_shape)
    self.self_attention_output_dense = self.self_attention._output_dense
    self.encdec_attention.build(input_shape)
    self.encdec_attention_output_dense = self.encdec_attention._output_dense
    self.encdec_attention_dropout = tf.keras.layers.Dropout(
        rate=self.hidden_dropout_prob)
...
...
@@ -145,14 +139,12 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
"TransformerDecoderBlock must have 4 inputs, but it got: %d"
%
len
(
inputs
))
input_tensor
,
memory
,
attention_mask
,
self_attention_mask
=
inputs
[:
4
]
if
cache
is
None
:
self_attention_inputs
=
[
input_tensor
,
input_tensor
,
self_attention_mask
]
else
:
self_attention_inputs
=
[
input_tensor
,
input_tensor
,
self_attention_mask
,
cache
]
self_attention_inputs
=
[
input_tensor
,
input_tensor
]
self_attention_output
,
cache
=
self
.
self_attention
(
self_attention_inputs
,
decode_loop_step
=
decode_loop_step
)
self_attention_inputs
,
attention_mask
=
self_attention_mask
,
cache
=
cache
,
decode_loop_step
=
decode_loop_step
)
self_attention_output
=
self
.
self_attention_dropout
(
self_attention_output
)
self_attention_output
=
self
.
self_attention_layer_norm
(
input_tensor
+
self_attention_output
)
...
...