ModelZoo / ResNet50_tensorflow / Commits / 330b34fe

Commit 330b34fe, authored Apr 21, 2020 by Hongkun Yu, committed by A. Unique TensorFlower on Apr 21, 2020
Internal change

PiperOrigin-RevId: 307689094
parent 74556d99
Showing 14 changed files with 258 additions and 177 deletions (+258, -177)
official/nlp/albert/tf2_albert_encoder_checkpoint_converter.py  +1  -1
official/nlp/bert/tf1_checkpoint_converter_lib.py  +3  -3
official/nlp/modeling/layers/attention.py  +112  -38
official/nlp/modeling/layers/attention_test.py  +48  -25
official/nlp/modeling/layers/rezero_transformer.py  +2  -17
official/nlp/modeling/layers/talking_heads_attention.py  +36  -14
official/nlp/modeling/layers/talking_heads_attention_test.py  +8  -8
official/nlp/modeling/layers/transformer.py  +7  -17
official/nlp/modeling/layers/transformer_scaffold.py  +2  -6
official/nlp/modeling/layers/transformer_scaffold_test.py  +11  -10
official/nlp/nhnet/decoder.py  +16  -18
official/nlp/nhnet/multi_channel_attention.py  +9  -16
official/nlp/nhnet/multi_channel_attention_test.py  +3  -3
official/nlp/nhnet/utils.py  +0  -1
official/nlp/albert/tf2_albert_encoder_checkpoint_converter.py  (view file @ 330b34fe)

@@ -54,7 +54,7 @@ ALBERT_NAME_REPLACEMENTS = (
     ("embedding_hidden_mapping_in", "embedding_projection"),
     ("group_0/inner_group_0/", ""),
     ("attention_1/self", "self_attention"),
-    ("attention_1/output/dense", "self_attention_output"),
+    ("attention_1/output/dense", "self_attention/attention_output"),
     ("LayerNorm/", "self_attention_layer_norm/"),
     ("ffn_1/intermediate/dense", "intermediate"),
     ("ffn_1/intermediate/output/dense", "output"),
official/nlp/bert/tf1_checkpoint_converter_lib.py  (view file @ 330b34fe)

@@ -47,7 +47,7 @@ BERT_V2_NAME_REPLACEMENTS = (
     ("embeddings/position_embeddings", "position_embedding/embeddings"),
     ("embeddings/LayerNorm", "embeddings/layer_norm"),
     ("attention/self", "self_attention"),
-    ("attention/output/dense", "self_attention_output"),
+    ("attention/output/dense", "self_attention/attention_output"),
     ("attention/output/LayerNorm", "self_attention_layer_norm"),
     ("intermediate/dense", "intermediate"),
     ("output/dense", "output"),

@@ -94,9 +94,9 @@ def _get_permutation(name, permutations):

 def _get_new_shape(name, shape, num_heads):
   """Checks whether a variable requires reshape by pattern matching."""
-  if "self_attention_output/kernel" in name:
+  if "self_attention/attention_output/kernel" in name:
     return tuple([num_heads, shape[0] // num_heads, shape[1]])
-  if "self_attention_output/bias" in name:
+  if "self_attention/attention_output/bias" in name:
     return shape
   patterns = [
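To make the rename concrete, here is a small standalone sketch (not from the repository) of the reshape rule that _get_new_shape applies to the renamed attention-output kernel: the flat [num_heads * size_per_head, hidden] TF1 variable is split per head. The variable name and sizes below are hypothetical.

import numpy as np

# Hypothetical TF1 checkpoint variable that matches the renamed pattern above.
name = "bert/encoder/layer_0/self_attention/attention_output/kernel"
num_heads = 12
kernel = np.zeros((768, 768))  # [num_heads * size_per_head, hidden_size]

if "self_attention/attention_output/kernel" in name:
  # Same reshape rule as _get_new_shape: split the first dimension into
  # [num_heads, size_per_head] so it matches the Keras DenseEinsum kernel.
  new_shape = (num_heads, kernel.shape[0] // num_heads, kernel.shape[1])
  kernel = kernel.reshape(new_shape)

print(kernel.shape)  # (12, 64, 768)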
official/nlp/modeling/layers/attention.py  (view file @ 330b34fe)

@@ -31,24 +31,31 @@ class MultiHeadAttention(tf.keras.layers.Layer):
   """MultiHeadAttention layer.

   This is an implementation of multi-headed attention based on "Attention
-  is all you Need". If `from_tensor` and `to_tensor` are the same, then
-  this is self-attention. Each timestep in `from_tensor` attends to the
-  corresponding sequence in `to_tensor`, and returns a fixed-width vector.
+  is all you Need". If `query`, `key,` `value` are the same, then
+  this is self-attention. Each timestep in `query` attends to the
+  corresponding sequence in `key`, and returns a fixed-width vector.

-  This function first projects `from_tensor` into a "query" tensor and
-  `to_tensor` into "key" and "value" tensors. These are (effectively) a list
-  of tensors of length `num_attention_heads`, where each tensor is of shape
-  [batch_size, seq_length, size_per_head].
+  This layer first projects `query`, `key` and `value`. These are
+  (effectively) a list of tensors of length `num_attention_heads`, where the
+  corresponding shapes are [batch_size, query_seq_length, key_size],
+  [batch_size, seq_length, key_size], [batch_size, seq_length, value_size].

   Then, the query and key tensors are dot-producted and scaled. These are
   softmaxed to obtain attention probabilities. The value tensors are then
   interpolated by these probabilities, then concatenated back to a single
-  tensor and returned.
+  tensor.
+
+  Finally, the result tensor with the last dimension as value_size can take an
+  linear projection and return.

   Arguments:
     num_heads: Number of attention heads.
-    head_size: Size of each attention head.
+    key_size: Size of each attention head for query and key.
+    value_size: Size of each attention head for value.
     dropout: Dropout probability.
+    use_bias: Boolean, whether the dense layers use bias vectors.
+    output_shape: The expected shape of an output tensor, besides the batch and
+      sequence dims. If not specified, projects back to the key feature dim.
     kernel_initializer: Initializer for dense layer kernels.
     bias_initializer: Initializer for dense layer biases.
     kernel_regularizer: Regularizer for dense layer kernels.

@@ -60,8 +67,11 @@ class MultiHeadAttention(tf.keras.layers.Layer):
   def __init__(self,
                num_heads,
-               head_size,
+               key_size,
+               value_size=None,
                dropout_rate=0.0,
+               use_bias=True,
+               output_shape=None,
                kernel_initializer="glorot_uniform",
                bias_initializer="zeros",
                kernel_regularizer=None,

@@ -72,8 +82,11 @@ class MultiHeadAttention(tf.keras.layers.Layer):
                **kwargs):
     super(MultiHeadAttention, self).__init__(**kwargs)
     self._num_heads = num_heads
-    self._head_size = head_size
+    self._key_size = key_size
+    self._value_size = value_size if value_size else key_size
     self._dropout_rate = dropout_rate
+    self._use_bias = use_bias
+    self._output_shape = output_shape
     self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
     self._bias_initializer = tf.keras.initializers.get(bias_initializer)
     self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)

@@ -82,7 +95,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
     self._bias_constraint = tf.keras.constraints.get(bias_constraint)

     self._query_dense = dense_einsum.DenseEinsum(
-        output_shape=(self._num_heads, self._head_size),
+        output_shape=(self._num_heads, self._key_size),
+        use_bias=self._use_bias,
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,

@@ -93,7 +107,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
         name="query")

     self._key_dense = dense_einsum.DenseEinsum(
-        output_shape=(self._num_heads, self._head_size),
+        output_shape=(self._num_heads, self._key_size),
+        use_bias=self._use_bias,
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,

@@ -104,7 +119,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
         name="key")

     self._value_dense = dense_einsum.DenseEinsum(
-        output_shape=(self._num_heads, self._head_size),
+        output_shape=(self._num_heads, self._value_size),
+        use_bias=self._use_bias,
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,

@@ -122,10 +138,16 @@ class MultiHeadAttention(tf.keras.layers.Layer):
     config = {
         "num_heads":
             self._num_heads,
-        "head_size":
-            self._head_size,
+        "key_size":
+            self._key_size,
+        "value_size":
+            self._value_size,
         "dropout_rate":
             self._dropout_rate,
+        "use_bias":
+            self._use_bias,
+        "output_shape":
+            self._output_shape,
         "kernel_initializer":
             tf.keras.initializers.serialize(self._kernel_initializer),
         "bias_initializer":

@@ -144,42 +166,92 @@ class MultiHeadAttention(tf.keras.layers.Layer):
     base_config = super(MultiHeadAttention, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))

-  def call(self, inputs):
-    from_tensor = inputs[0]
-    to_tensor = inputs[1]
-    attention_mask = inputs[2] if len(inputs) == 3 else None
+  def build(self, input_shape):
+    if self._output_shape:
+      output_shape = self._output_shape
+    else:
+      input_shape = tf.TensorShape(input_shape[0])
+      output_shape = input_shape[-1]
+    self._output_dense = dense_einsum.DenseEinsum(
+        output_shape=output_shape,
+        num_summed_dimensions=2,
+        kernel_initializer=self._kernel_initializer,
+        bias_initializer=self._bias_initializer,
+        kernel_regularizer=self._kernel_regularizer,
+        bias_regularizer=self._bias_regularizer,
+        activity_regularizer=self._activity_regularizer,
+        kernel_constraint=self._kernel_constraint,
+        bias_constraint=self._bias_constraint,
+        name="attention_output")
+    super(MultiHeadAttention, self).build(input_shape)
+
+  def call(self, inputs, attention_mask=None):
+    """Implements the forward pass.
+
+    Size glossary:
+      * Number of heads (H): the number of attention heads.
+      * Value size (V): the size of each value embedding per head.
+      * Key size (K): the size of each key embedding per head. Equally, the size
+          of each query embedding per head. Typically K <= V.
+      * Batch size (B).
+      * Query (target) sequence length (T).
+      * Value (source) sequence length (S).
+
+    Args:
+      inputs: List of the following tensors:
+        * query: Query `Tensor` of shape `[B, T, dim]`.
+        * value: Value `Tensor` of shape `[B, S, dim]`.
+        * key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will
+          use `value` for both `key` and `value`, which is the most common case.
+      attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
+        attention to certain positions.
+
+    Returns:
+      attention_output: The result of the computation, of shape [B, F, N, V] or
+        [B, F, E], where `N` is the number of heads and `E` is the query input
+        last dimension.
+    """
+    inputs_len = len(inputs)
+    if inputs_len > 3 or inputs_len < 2:
+      raise ValueError(
+          "Expects inputs list of length 2 or 3, namely [query, value] or "
+          "[query, value, key]. "
+          "Given length: %d" % inputs_len)
+    query = inputs[0]
+    value = inputs[1]
+    key = inputs[2] if inputs_len == 3 else value

-    # Scalar dimensions referenced here:
-    #   B = batch size (number of sequences)
-    #   F = `from_tensor` sequence length
-    #   T = `to_tensor` sequence length
     #   N = `num_attention_heads`
     #   H = `size_per_head`
-    # `query_tensor` = [B, F, N ,H]
-    query_tensor = self._query_dense(from_tensor)
+    # `query_tensor` = [B, T, N ,H]
+    query_tensor = self._query_dense(query)

-    # `key_tensor` = [B, T, N, H]
-    key_tensor = self._key_dense(to_tensor)
+    # `key_tensor` = [B, S, N, H]
+    key_tensor = self._key_dense(key)

-    # `value_tensor` = [B, T, N, H]
-    value_tensor = self._value_dense(to_tensor)
+    # `value_tensor` = [B, S, N, H]
+    value_tensor = self._value_dense(value)

     # Take the dot product between "query" and "key" to get the raw
     # attention scores.
-    attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
+    attention_scores = tf.einsum("BSNH,BTNH->BNTS", key_tensor, query_tensor)
     attention_scores = tf.multiply(attention_scores,
-                                   1.0 / math.sqrt(float(self._head_size)))
+                                   1.0 / math.sqrt(float(self._key_size)))

     # Normalize the attention scores to probabilities.
-    # `attention_probs` = [B, N, F, T]
+    # `attention_probs` = [B, N, T, S]
     attention_probs = self._masked_softmax([attention_scores, attention_mask])

     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
     attention_probs = self._dropout(attention_probs)

-    # `context_layer` = [B, F, N, H]
-    return tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor)
+    # `context_layer` = [B, T, N, H]
+    attention_output = tf.einsum("BNTS,BSNH->BTNH", attention_probs,
+                                 value_tensor)
+    attention_output = self._output_dense(attention_output)
+    return attention_output


 @tf.keras.utils.register_keras_serializable(package="Text")

@@ -244,7 +316,7 @@ class CachedAttention(MultiHeadAttention):
     # attention scores.
     attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
     attention_scores = tf.multiply(attention_scores,
-                                   1.0 / math.sqrt(float(self._head_size)))
+                                   1.0 / math.sqrt(float(self._key_size)))

     # Normalize the attention scores to probabilities.
     # `attention_probs` = [B, N, F, T]

@@ -253,6 +325,8 @@ class CachedAttention(MultiHeadAttention):
     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
     attention_probs = self._dropout(attention_probs)

     # `context_layer` = [B, F, N, H]
-    return tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor), cache
+    attention_output = tf.einsum("BNFT,BTNH->BFNH", attention_probs,
+                                 value_tensor)
+    attention_output = self._output_dense(attention_output)
+    return attention_output, cache
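As a quick orientation for the new interface, here is a minimal usage sketch. It assumes the official.nlp.modeling.layers.attention module from this commit is importable; the head counts, widths, and sequence lengths are made up.

import tensorflow as tf
from official.nlp.modeling.layers import attention

# Cross-attention with separate key/value head sizes and a custom output width.
layer = attention.MultiHeadAttention(
    num_heads=12, key_size=64, value_size=80, output_shape=128)

query = tf.keras.Input(shape=(40, 256))  # [B, T, dim]
value = tf.keras.Input(shape=(20, 256))  # [B, S, dim]
mask = tf.keras.Input(shape=(40, 20))    # [B, T, S]

# The mask is now passed separately from `inputs`, and the layer applies its
# own output projection, so the result has shape [B, T, output_shape].
output = layer([query, value], mask)
print(output.shape)  # (None, 40, 128)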
official/nlp/modeling/layers/attention_test.py  (view file @ 330b34fe)

@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf

@@ -30,34 +31,44 @@ from official.nlp.modeling.layers import attention
 @keras_parameterized.run_all_keras_modes
 class MultiHeadAttentionTest(keras_parameterized.TestCase):

-  def test_non_masked_attention(self):
+  @parameterized.named_parameters(
+      ("key_value_same_proj", None, None, [40, 80]),
+      ("key_value_different_proj", 32, 60, [40, 60]),
+  )
+  def test_non_masked_attention(self, value_size, output_shape, output_dims):
     """Test that the attention layer can be created without a mask tensor."""
-    test_layer = attention.MultiHeadAttention(num_heads=12, head_size=64)
+    test_layer = attention.MultiHeadAttention(
+        num_heads=12,
+        key_size=64,
+        value_size=value_size,
+        output_shape=output_shape)
     # Create a 3-dimensional input (the first dimension is implicit).
-    from_tensor = tf.keras.Input(shape=(40, 80))
-    to_tensor = tf.keras.Input(shape=(20, 80))
-    output = test_layer([from_tensor, to_tensor])
-    self.assertEqual(output.shape.as_list(), [None, 40, 12, 64])
+    query = tf.keras.Input(shape=(40, 80))
+    value = tf.keras.Input(shape=(20, 80))
+    output = test_layer([query, value])
+    self.assertEqual(output.shape.as_list(), [None] + output_dims)

   def test_non_masked_self_attention(self):
     """Test with one input (self-attenntion) and no mask tensor."""
-    test_layer = attention.MultiHeadAttention(num_heads=12, head_size=64)
+    test_layer = attention.MultiHeadAttention(num_heads=12, key_size=64)
     # Create a 3-dimensional input (the first dimension is implicit).
-    from_tensor = tf.keras.Input(shape=(40, 80))
-    output = test_layer([from_tensor, from_tensor])
-    self.assertEqual(output.shape.as_list(), [None, 40, 12, 64])
+    query = tf.keras.Input(shape=(40, 80))
+    output = test_layer([query, query])
+    self.assertEqual(output.shape.as_list(), [None, 40, 80])

-  def test_masked_attention(self):
+  @parameterized.parameters(True, False)
+  def test_masked_attention(self, use_bias):
     """Test with a mask tensor."""
-    test_layer = attention.MultiHeadAttention(num_heads=2, head_size=2)
+    test_layer = attention.MultiHeadAttention(
+        num_heads=2, key_size=2, use_bias=use_bias)
     # Create a 3-dimensional input (the first dimension is implicit).
-    from_tensor = tf.keras.Input(shape=(4, 8))
-    to_tensor = tf.keras.Input(shape=(2, 8))
+    query = tf.keras.Input(shape=(4, 8))
+    value = tf.keras.Input(shape=(2, 8))
     mask_tensor = tf.keras.Input(shape=(4, 2))
-    output = test_layer([from_tensor, to_tensor, mask_tensor])
+    output = test_layer([query, value], mask_tensor)

     # Create a model containing the test layer.
-    model = tf.keras.Model([from_tensor, to_tensor, mask_tensor], output)
+    model = tf.keras.Model([query, value, mask_tensor], output)

     # Generate data for the input (non-mask) tensors.
     from_data = 10 * np.random.random_sample((3, 4, 8))

@@ -76,16 +87,28 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
     # same.
     self.assertNotAllClose(masked_output_data, unmasked_output_data)

+    # Tests the layer with three inputs: Q, K, V.
+    key = tf.keras.Input(shape=(2, 8))
+    output = test_layer([query, value, key], mask_tensor)
+    model = tf.keras.Model([query, value, key, mask_tensor], output)
+
+    masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
+    unmasked_output_data = model.predict(
+        [from_data, to_data, to_data, null_mask_data])
+    # Because one data is masked and one is not, the outputs should not be the
+    # same.
+    self.assertNotAllClose(masked_output_data, unmasked_output_data)
+
   def test_initializer(self):
     """Test with a specified initializer."""
     test_layer = attention.MultiHeadAttention(
         num_heads=12,
-        head_size=64,
+        key_size=64,
         kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
     # Create a 3-dimensional input (the first dimension is implicit).
-    from_tensor = tf.keras.Input(shape=(40, 80))
-    output = test_layer([from_tensor, from_tensor])
-    self.assertEqual(output.shape.as_list(), [None, 40, 12, 64])
+    query = tf.keras.Input(shape=(40, 80))
+    output = test_layer([query, query])
+    self.assertEqual(output.shape.as_list(), [None, 40, 80])


 def _create_cache(batch_size, init_decode_length, num_heads, head_size):

@@ -112,7 +135,7 @@ class CachedAttentionTest(keras_parameterized.TestCase):
     init_decode_length = 0
     # Directly tests the keras layer.
     cache = _create_cache(batch_size, init_decode_length, num_heads, head_size)
-    layer = attention.CachedAttention(num_heads=num_heads, head_size=head_size)
+    layer = attention.CachedAttention(num_heads=num_heads, key_size=head_size)

     # Generate data for the input (non-mask) tensors.
     from_data = tf.zeros((batch_size, from_seq_length, 8), dtype=np.float32)

@@ -121,12 +144,12 @@ class CachedAttentionTest(keras_parameterized.TestCase):
     mask_data = np.random.randint(
         2, size=(batch_size, from_seq_length, from_seq_length))
     masked_output_data, cache = layer([from_data, from_data, mask_data, cache])
-    self.assertEqual(masked_output_data.shape, (3, 4, 2, 2))
+    self.assertEqual(masked_output_data.shape, (3, 4, 8))
     self.assertEqual(cache["value"].shape, (3, 4, 2, 2))

     # Tests inputs without cache.
     masked_output_data, cache = layer([from_data, from_data, mask_data])
-    self.assertEqual(masked_output_data.shape, (3, 4, 2, 2))
+    self.assertEqual(masked_output_data.shape, (3, 4, 8))
     self.assertIsNone(cache)

   def test_padded_decode(self):

@@ -139,7 +162,7 @@ class CachedAttentionTest(keras_parameterized.TestCase):
     # Directly tests the keras layer.
     cache = _create_cache(batch_size, init_decode_length, num_heads, head_size)
-    layer = attention.CachedAttention(num_heads=num_heads, head_size=head_size)
+    layer = attention.CachedAttention(num_heads=num_heads, key_size=head_size)

     # Generate data for the input (non-mask) tensors.
     from_data = tf.zeros((batch_size, from_seq_length, 8), dtype=np.float32)

@@ -149,7 +172,7 @@ class CachedAttentionTest(keras_parameterized.TestCase):
     # Testing the invocation directly as Keras cannot consume inputs correctly.
     masked_output_data, cache = layer([from_data, from_data, mask_data, cache],
                                       decode_loop_step=decode_loop_step)
-    self.assertEqual(masked_output_data.shape, (3, 4, 2, 2))
+    self.assertEqual(masked_output_data.shape, (3, 4, 8))
     self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
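A hedged sketch of what the updated cache-less assertion exercises (assuming the modules from this commit are on the path; the batch and sequence sizes mirror the test):

import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import attention

batch_size, seq_length, num_heads, head_size = 3, 4, 2, 2
layer = attention.CachedAttention(num_heads=num_heads, key_size=head_size)

from_data = tf.zeros((batch_size, seq_length, 8), dtype=np.float32)
mask_data = np.random.randint(2, size=(batch_size, seq_length, seq_length))

# With no cache supplied the layer returns (output, None); the output now has
# the projected feature dim (8 here) instead of the per-head [N, H] shape.
output, cache = layer([from_data, from_data, mask_data])
print(output.shape)  # (3, 4, 8)
print(cache)         # None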
official/nlp/modeling/layers/rezero_transformer.py  (view file @ 330b34fe)

@@ -108,7 +108,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
     self._attention_layer = attention.MultiHeadAttention(
         num_heads=self._num_heads,
-        head_size=self._attention_head_size,
+        key_size=self._attention_head_size,
         dropout_rate=self._attention_dropout_rate,
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,

@@ -118,17 +118,6 @@ class ReZeroTransformer(tf.keras.layers.Layer):
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint,
         name="self_attention")
-    self._attention_output_dense = dense_einsum.DenseEinsum(
-        output_shape=hidden_size,
-        num_summed_dimensions=2,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="self_attention_output")
     self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     if self._use_layer_norm:
       # Use float32 in layernorm for numeric stability.

@@ -218,11 +207,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
     attention_inputs = [input_tensor, input_tensor]

-    if attention_mask is not None:
-      attention_inputs.append(attention_mask)
-
-    attention_output = self._attention_layer(attention_inputs)
-    attention_output = self._attention_output_dense(attention_output)
+    attention_output = self._attention_layer(attention_inputs, attention_mask)
     attention_output = self._attention_dropout(attention_output)
     attention_output = input_tensor + self._rezero_a * attention_output
     if self._use_layer_norm:
official/nlp/modeling/layers/talking_heads_attention.py  (view file @ 330b34fe)

@@ -31,8 +31,10 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
   Arguments:
     num_heads: Number of attention heads.
-    head_size: Size of each attention head.
-    dropout: Dropout probability.
+    key_size: Size of each attention head.
+    dropout_rate: Dropout probability.
+    output_shape: The expected shape of an output tensor, besides the batch and
+      sequence dims. If not specified, projects back to the key feature dim.
     kernel_initializer: Initializer for dense layer kernels.
     bias_initializer: Initializer for dense layer biases.
     kernel_regularizer: Regularizer for dense layer kernels.

@@ -44,8 +46,9 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
   def __init__(self,
                num_heads,
-               head_size,
+               key_size,
                dropout_rate=0.0,
+               output_shape=None,
                kernel_initializer="glorot_uniform",
                bias_initializer="zeros",
                kernel_regularizer=None,

@@ -56,8 +59,9 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
                **kwargs):
     super(TalkingHeadsAttention, self).__init__(**kwargs)
     self._num_heads = num_heads
-    self._head_size = head_size
+    self._key_size = key_size
     self._dropout_rate = dropout_rate
+    self._output_shape = output_shape
     self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
     self._bias_initializer = tf.keras.initializers.get(bias_initializer)
     self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)

@@ -66,7 +70,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
     self._bias_constraint = tf.keras.constraints.get(bias_constraint)

     self._query_dense = dense_einsum.DenseEinsum(
-        output_shape=(self._num_heads, self._head_size),
+        output_shape=(self._num_heads, self._key_size),
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,

@@ -77,7 +81,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
         name="query")

     self._key_dense = dense_einsum.DenseEinsum(
-        output_shape=(self._num_heads, self._head_size),
+        output_shape=(self._num_heads, self._key_size),
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,

@@ -88,7 +92,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
         name="key")

     self._value_dense = dense_einsum.DenseEinsum(
-        output_shape=(self._num_heads, self._head_size),
+        output_shape=(self._num_heads, self._key_size),
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,

@@ -103,7 +107,22 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
     self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)

   def build(self, input_shape):
-    super(TalkingHeadsAttention, self).build(input_shape)
+    if self._output_shape:
+      output_shape = self._output_shape
+    else:
+      input_shape = tf.TensorShape(input_shape[0])
+      output_shape = input_shape[-1]
+    self._output_dense = dense_einsum.DenseEinsum(
+        output_shape=output_shape,
+        num_summed_dimensions=2,
+        kernel_initializer=self._kernel_initializer,
+        bias_initializer=self._bias_initializer,
+        kernel_regularizer=self._kernel_regularizer,
+        bias_regularizer=self._bias_regularizer,
+        activity_regularizer=self._activity_regularizer,
+        kernel_constraint=self._kernel_constraint,
+        bias_constraint=self._bias_constraint,
+        name="attention_output")
     self._pre_softmax_weight = self.add_weight(
         "pre_softmax_weight",
         shape=(self._num_heads, self._num_heads),

@@ -120,13 +139,14 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
         constraint=self._kernel_constraint,
         dtype=self.dtype,
         trainable=True)
+    super(TalkingHeadsAttention, self).build(input_shape)

   def get_config(self):
     config = {
         "num_heads":
             self._num_heads,
-        "head_size":
-            self._head_size,
+        "key_size":
+            self._key_size,
         "dropout_rate":
             self._dropout_rate,
         "kernel_initializer":

@@ -147,10 +167,9 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
     base_config = super(TalkingHeadsAttention, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))

-  def call(self, inputs):
+  def call(self, inputs, attention_mask=None):
     from_tensor = inputs[0]
     to_tensor = inputs[1]
-    attention_mask = inputs[2] if len(inputs) == 3 else None

     # Scalar dimensions referenced here:
     #   B = batch size (number of sequences)

@@ -171,7 +190,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
     # attention scores.
     attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
     attention_scores = tf.multiply(attention_scores,
-                                   1.0 / math.sqrt(float(self._head_size)))
+                                   1.0 / math.sqrt(float(self._key_size)))

     # Apply talking heads before softmax.
     attention_scores = tf.einsum("BNFT,NL->BLFT", attention_scores,

@@ -190,4 +209,7 @@ class TalkingHeadsAttention(tf.keras.layers.Layer):
     attention_probs = self._dropout(attention_probs)

     # `context_layer` = [B, F, N, H]
-    return tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor)
+    attention_output = tf.einsum("BNFT,BTNH->BFNH", attention_probs,
+                                 value_tensor)
+    attention_output = self._output_dense(attention_output)
+    return attention_output
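To see what the pre-softmax talking-heads mixing does on its own, here is a small self-contained sketch in plain TensorFlow; the shapes are made up, and it mirrors the "BNFT,NL->BLFT" einsum kept by the layer above.

import tensorflow as tf

batch, num_heads, from_len, to_len = 2, 4, 5, 7
scores = tf.random.normal([batch, num_heads, from_len, to_len])  # [B, N, F, T]
pre_softmax_weight = tf.random.normal([num_heads, num_heads])    # [N, L]

# Mix attention scores across heads before the softmax, as the layer does.
mixed = tf.einsum("BNFT,NL->BLFT", scores, pre_softmax_weight)
probs = tf.nn.softmax(mixed, axis=-1)
print(probs.shape)  # (2, 4, 5, 7)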
official/nlp/modeling/layers/talking_heads_attention_test.py  (view file @ 330b34fe)

@@ -33,31 +33,31 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
   def test_non_masked_attention(self):
     """Test that the attention layer can be created without a mask tensor."""
     test_layer = talking_heads_attention.TalkingHeadsAttention(
-        num_heads=12, head_size=64)
+        num_heads=12, key_size=64)
     # Create a 3-dimensional input (the first dimension is implicit).
     from_tensor = tf.keras.Input(shape=(40, 80))
     to_tensor = tf.keras.Input(shape=(20, 80))
     output = test_layer([from_tensor, to_tensor])
-    self.assertEqual(output.shape.as_list(), [None, 40, 12, 64])
+    self.assertEqual(output.shape.as_list(), [None, 40, 80])

   def test_non_masked_self_attention(self):
     """Test with one input (self-attenntion) and no mask tensor."""
     test_layer = talking_heads_attention.TalkingHeadsAttention(
-        num_heads=12, head_size=64)
+        num_heads=12, key_size=64)
     # Create a 3-dimensional input (the first dimension is implicit).
     from_tensor = tf.keras.Input(shape=(40, 80))
     output = test_layer([from_tensor, from_tensor])
-    self.assertEqual(output.shape.as_list(), [None, 40, 12, 64])
+    self.assertEqual(output.shape.as_list(), [None, 40, 80])

   def test_masked_attention(self):
     """Test with a mask tensor."""
     test_layer = talking_heads_attention.TalkingHeadsAttention(
-        num_heads=2, head_size=2)
+        num_heads=2, key_size=2)
     # Create a 3-dimensional input (the first dimension is implicit).
     from_tensor = tf.keras.Input(shape=(4, 8))
     to_tensor = tf.keras.Input(shape=(2, 8))
     mask_tensor = tf.keras.Input(shape=(4, 2))
-    output = test_layer([from_tensor, to_tensor, mask_tensor])
+    output = test_layer([from_tensor, to_tensor], mask_tensor)

     # Create a model containing the test layer.
     model = tf.keras.Model([from_tensor, to_tensor, mask_tensor], output)

@@ -83,12 +83,12 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
     """Test with a specified initializer."""
     test_layer = talking_heads_attention.TalkingHeadsAttention(
         num_heads=12,
-        head_size=64,
+        key_size=64,
         kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
     # Create a 3-dimensional input (the first dimension is implicit).
     from_tensor = tf.keras.Input(shape=(40, 80))
     output = test_layer([from_tensor, from_tensor])
-    self.assertEqual(output.shape.as_list(), [None, 40, 12, 64])
+    self.assertEqual(output.shape.as_list(), [None, 40, 80])


 if __name__ == "__main__":
official/nlp/modeling/layers/transformer.py  (view file @ 330b34fe)

@@ -102,7 +102,7 @@ class Transformer(tf.keras.layers.Layer):
     self._attention_layer = attention.MultiHeadAttention(
         num_heads=self._num_heads,
-        head_size=self._attention_head_size,
+        key_size=self._attention_head_size,
         dropout_rate=self._attention_dropout_rate,
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,

@@ -112,17 +112,11 @@ class Transformer(tf.keras.layers.Layer):
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint,
         name="self_attention")
-    self._attention_output_dense = dense_einsum.DenseEinsum(
-        output_shape=hidden_size,
-        num_summed_dimensions=2,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="self_attention_output")
+    # TODO(hongkuny): Remove when checkpoint backward compatibility is resolved.
+    # pylint: disable=protected-access
+    self._attention_layer.build([input_tensor_shape])
+    self._attention_output_dense = self._attention_layer._output_dense
     self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     # Use float32 in layernorm for numeric stability.
     # It is probably safe in mixed_float16, but we haven't validated this yet.

@@ -200,11 +194,7 @@ class Transformer(tf.keras.layers.Layer):
     attention_inputs = [input_tensor, input_tensor]

-    if attention_mask is not None:
-      attention_inputs.append(attention_mask)
-
-    attention_output = self._attention_layer(attention_inputs)
-    attention_output = self._attention_output_dense(attention_output)
+    attention_output = self._attention_layer(attention_inputs, attention_mask)
     attention_output = self._attention_dropout(attention_output)
     attention_output = self._attention_layer_norm(input_tensor +
                                                   attention_output)
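The pattern above, building the attention layer eagerly and then aliasing its internal projection, is what keeps older checkpoint variable names resolvable while the projection moves inside the attention layer. A minimal sketch of the same trick, assuming the modified attention module is importable; the shapes are illustrative only:

import tensorflow as tf
from official.nlp.modeling.layers import attention

layer = attention.MultiHeadAttention(num_heads=2, key_size=4)
# Build eagerly with the query shape so the internal output projection exists,
# then alias it under the attribute name older code and checkpoints expect.
layer.build([tf.TensorShape([None, 8, 16])])
attention_output_dense = layer._output_dense  # pylint: disable=protected-access
print(attention_output_dense.name)  # expected to be "attention_output"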
official/nlp/modeling/layers/transformer_scaffold.py  (view file @ 330b34fe)

@@ -118,7 +118,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
     if self._attention_cfg is None:
       attention_cfg = {
           "num_heads": self._num_heads,
-          "head_size": self._attention_head_size,
+          "key_size": self._attention_head_size,
           "dropout_rate": self._attention_dropout_rate,
           "kernel_initializer": self._kernel_initializer,
           "bias_initializer": self._bias_initializer,

@@ -219,11 +219,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
     attention_inputs = [input_tensor, input_tensor]

-    if attention_mask is not None:
-      attention_inputs.append(attention_mask)
-
-    attention_output = self._attention_layer(attention_inputs)
-    attention_output = self._attention_output_dense(attention_output)
+    attention_output = self._attention_layer(attention_inputs, attention_mask)
     attention_output = self._attention_dropout(attention_output)
     attention_output = self._attention_layer_norm(input_tensor +
                                                   attention_output)
official/nlp/modeling/layers/transformer_scaffold_test.py  (view file @ 330b34fe)

@@ -39,9 +39,10 @@ class ValidatedAttentionLayer(attention.MultiHeadAttention):
     super(ValidatedAttentionLayer, self).__init__(**kwargs)
     self.list = call_list

-  def call(self, inputs):
+  def call(self, inputs, attention_mask=None):
     self.list.append(True)
-    return super(ValidatedAttentionLayer, self).call(inputs)
+    return super(ValidatedAttentionLayer, self).call(
+        inputs, attention_mask=attention_mask)

   def get_config(self):
     config = super(ValidatedAttentionLayer, self).get_config()

@@ -65,7 +66,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     call_list = []
     attention_layer_cfg = {
         'num_heads': 10,
-        'head_size': 8,
+        'key_size': 8,
         'call_list': call_list,
     }
     test_layer = transformer_scaffold.TransformerScaffold(

@@ -93,7 +94,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     call_list = []
     attention_layer_cfg = {
         'num_heads': 10,
-        'head_size': 8,
+        'key_size': 8,
         'call_list': call_list,
     }
     test_layer = transformer_scaffold.TransformerScaffold(

@@ -122,7 +123,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     call_list = []
     attention_layer_cfg = {
         'num_heads': 10,
-        'head_size': 8,
+        'key_size': 8,
         'call_list': call_list,
     }
     test_layer = transformer_scaffold.TransformerScaffold(

@@ -146,7 +147,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     call_list = []
     attention_layer_cfg = {
         'num_heads': 10,
-        'head_size': 8,
+        'key_size': 8,
         'call_list': call_list,
     }
     test_layer = transformer_scaffold.TransformerScaffold(

@@ -181,7 +182,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     call_list = []
     attention_layer_cfg = {
         'num_heads': 10,
-        'head_size': 8,
+        'key_size': 8,
         'call_list': call_list,
     }
     test_layer = transformer_scaffold.TransformerScaffold(

@@ -223,7 +224,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     call_list = []
     attention_layer_cfg = {
         'num_heads': 10,
-        'head_size': 8,
+        'key_size': 8,
         'call_list': call_list,
     }
     test_layer = transformer_scaffold.TransformerScaffold(

@@ -264,7 +265,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     call_list = []
     attention_layer_cfg = {
         'num_heads': 10,
-        'head_size': 8,
+        'key_size': 8,
         'call_list': call_list,
     }
     test_layer = transformer_scaffold.TransformerScaffold(

@@ -292,7 +293,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     call_list = []
     attention_layer_cfg = {
         'num_heads': 10,
-        'head_size': 8,
+        'key_size': 8,
         'call_list': call_list,
         'name': 'test_layer',
     }
official/nlp/nhnet/decoder.py
View file @
330b34fe
...
@@ -68,11 +68,11 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
...
@@ -68,11 +68,11 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
"heads (%d)"
%
(
self
.
hidden_size
,
self
.
num_attention_heads
))
"heads (%d)"
%
(
self
.
hidden_size
,
self
.
num_attention_heads
))
self
.
attention_head_size
=
int
(
self
.
hidden_size
/
self
.
num_attention_heads
)
self
.
attention_head_size
=
int
(
self
.
hidden_size
/
self
.
num_attention_heads
)
def
build
(
self
,
unused_
input_shape
s
):
def
build
(
self
,
input_shape
):
# Self attention.
# Self attention.
self
.
self_attention
=
layers
.
CachedAttention
(
self
.
self_attention
=
layers
.
CachedAttention
(
num_heads
=
self
.
num_attention_heads
,
num_heads
=
self
.
num_attention_heads
,
head
_size
=
self
.
attention_head_size
,
key
_size
=
self
.
attention_head_size
,
dropout_rate
=
self
.
attention_probs_dropout_prob
,
dropout_rate
=
self
.
attention_probs_dropout_prob
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_initializer
=
self
.
_kernel_initializer
,
name
=
"self_attention"
)
name
=
"self_attention"
)
...
@@ -90,16 +90,18 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
...
@@ -90,16 +90,18 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
     # Encoder-decoder attention.
     self.encdec_attention = self._cross_attention_cls(
         num_heads=self.num_attention_heads,
-        head_size=self.attention_head_size,
+        key_size=self.attention_head_size,
         dropout_rate=self.attention_probs_dropout_prob,
-        kernel_initializer=self._kernel_initializer,
-        name="attention/encdec")
-    self.encdec_attention_output_dense = layers.DenseEinsum(
         output_shape=self.hidden_size,
-        num_summed_dimensions=2,
         kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        name="attention/encdec_output")
+        name="attention/encdec")
+    # TODO(hongkuny): Remove when checkpoint backward compatibility is resolved.
+    # pylint: disable=protected-access
+    self.self_attention.build(input_shape)
+    self.self_attention_output_dense = self.self_attention._output_dense
+    self.encdec_attention.build(input_shape)
+    self.encdec_attention_output_dense = self.encdec_attention._output_dense
     self.encdec_attention_dropout = tf.keras.layers.Dropout(
         rate=self.hidden_dropout_prob)
     self.encdec_attention_layer_norm = (
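The notable part of this hunk is the checkpoint-compatibility shim: the attention layers now own their output projection (`_output_dense`), so the decoder's old `*_output_dense` attributes are kept alive as aliases to those inner layers instead of separate DenseEinsum layers. A minimal sketch of that pattern, using toy layers rather than the repo's CachedAttention/DenseEinsum classes:

import tensorflow as tf


class ToyAttention(tf.keras.layers.Layer):
  """Stands in for an attention layer that owns its output projection."""

  def __init__(self, units, **kwargs):
    super(ToyAttention, self).__init__(**kwargs)
    self._output_dense = tf.keras.layers.Dense(units, name="attention_output")

  def call(self, inputs):
    return self._output_dense(inputs)


class ToyBlock(tf.keras.layers.Layer):
  """Keeps the legacy attribute name pointing at the sublayer's projection."""

  def build(self, input_shape):
    self.self_attention = ToyAttention(units=int(input_shape[-1]),
                                       name="self_attention")
    # Old checkpoints reached the projection through the block itself; alias
    # it so the legacy attribute still resolves to the same variables.
    # pylint: disable=protected-access
    self.self_attention.build(input_shape)
    self.self_attention_output_dense = self.self_attention._output_dense
    super(ToyBlock, self).build(input_shape)

  def call(self, inputs):
    return self.self_attention(inputs)


block = ToyBlock()
out = block(tf.ones((2, 4, 8)))
assert block.self_attention_output_dense is block.self_attention._output_dense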
...
@@ -123,14 +125,13 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
     self.output_dropout = tf.keras.layers.Dropout(rate=self.hidden_dropout_prob)
     self.output_layer_norm = tf.keras.layers.LayerNormalization(
         name="output_layer_norm", axis=-1, epsilon=1e-12)
-    super(TransformerDecoderBlock, self).build(unused_input_shapes)
+    super(TransformerDecoderBlock, self).build(input_shape)

   def common_layers_with_encoder(self):
     """Gets layer objects that can make a Transformer encoder block."""
     return [
-        self.self_attention, self.self_attention_output_dense,
+        self.self_attention,
         self.self_attention_layer_norm,
         self.intermediate_dense,
         self.output_dense, self.output_layer_norm
     ]

   def call(self, inputs, cache=None, decode_loop_step=None):
...
@@ -152,18 +153,15 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
     ]
     self_attention_output, cache = self.self_attention(
         self_attention_inputs, decode_loop_step=decode_loop_step)
-    self_attention_output = self.self_attention_output_dense(
-        self_attention_output)
     self_attention_output = self.self_attention_dropout(self_attention_output)
     self_attention_output = self.self_attention_layer_norm(
         input_tensor + self_attention_output)
-    cross_attn_inputs = [self_attention_output, memory, attention_mask]
+    cross_attn_inputs = [self_attention_output, memory]
     if self.multi_channel_cross_attention:
       # Accesses the 5-th input tensor for the doc-attention probabilities.
       cross_attn_inputs.append(inputs[-1])
-    attention_output = self.encdec_attention(cross_attn_inputs)
-    attention_output = self.encdec_attention_output_dense(attention_output)
+    attention_output = self.encdec_attention(cross_attn_inputs, attention_mask)
     attention_output = self.encdec_attention_dropout(attention_output)
     attention_output = self.encdec_attention_layer_norm(self_attention_output +
                                                         attention_output)
...
official/nlp/nhnet/multi_channel_attention.py View file @ 330b34fe
...
@@ -98,24 +98,14 @@ class DocAttention(tf.keras.layers.Layer):
 class MultiChannelAttention(layers.MultiHeadAttention):
   """Multi-channel Attention layer."""

-  def __init__(self, num_heads, head_size, **kwargs):
-    super(MultiChannelAttention, self).__init__(num_heads, head_size, **kwargs)
+  def __init__(self, num_heads, key_size, **kwargs):
+    super(MultiChannelAttention, self).__init__(num_heads, key_size, **kwargs)
     self._masked_softmax = layers.MaskedSoftmax(mask_expansion_axes=[2])

-  def compute_output_shape(self, input_shape):
-    if len(input_shape) != 4:
-      raise ValueError("Layer %s must have 4 input tensors." % self.name)
-    from_tensor_shape = tf.TensorShape(input_shape[0])
-    batch = from_tensor_shape[0]
-    from_tensor_length = from_tensor_shape[1]
-    return tf.TensorShape(
-        (batch, from_tensor_length, self._num_heads, self._head_size))
-
-  def call(self, inputs):
+  def call(self, inputs, attention_mask=None):
     from_tensor = inputs[0]
     to_tensor = inputs[1]
-    attention_mask = inputs[2]
-    doc_attention_probs = inputs[3]
+    doc_attention_probs = inputs[2]

     # Scalar dimensions referenced here:
     # B = batch size (number of stories)
...
@@ -137,7 +127,7 @@ class MultiChannelAttention(layers.MultiHeadAttention):
     # attention scores.
     attention_scores = tf.einsum("BATNH,BFNH->BANFT", key_tensor, query_tensor)
     attention_scores = tf.multiply(attention_scores,
-                                   1.0 / math.sqrt(float(self._head_size)))
+                                   1.0 / math.sqrt(float(self._key_size)))

     # Normalize the attention scores to probabilities.
     # `attention_probs` = [B, A, N, F, T]
...
@@ -150,4 +140,7 @@ class MultiChannelAttention(layers.MultiHeadAttention):
     # `context_layer` = [B, F, N, H]
     context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs,
                               value_tensor)
-    return tf.einsum("BNFA,BAFNH->BFNH", doc_attention_probs, context_layer)
+    attention_output = tf.einsum("BNFA,BAFNH->BFNH", doc_attention_probs,
+                                 context_layer)
+    attention_output = self._output_dense(attention_output)
+    return attention_output
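The einsum pipeline above is easiest to follow as a shape walkthrough. The sketch below replays the contractions with random tensors and the dimension names from the comments (B batch, A channels/docs, F from-length, T to-length, N heads, H per-head size); the masked softmax and the trainable projections are left out.

import math

import tensorflow as tf

B, A, F, T, N, H = 3, 5, 4, 2, 2, 8

query_tensor = tf.random.normal((B, F, N, H))            # [B, F, N, H]
key_tensor = tf.random.normal((B, A, T, N, H))           # [B, A, T, N, H]
value_tensor = tf.random.normal((B, A, T, N, H))         # [B, A, T, N, H]
doc_attention_probs = tf.random.uniform((B, N, F, A))    # [B, N, F, A]

# Per-channel scaled dot-product scores, softmaxed over the to-length axis.
attention_scores = tf.einsum("BATNH,BFNH->BANFT", key_tensor, query_tensor)
attention_scores = attention_scores / math.sqrt(float(H))
attention_probs = tf.nn.softmax(attention_scores, axis=-1)   # [B, A, N, F, T]

# Per-channel context, then the doc-attention probabilities collapse the
# channel axis A.
context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs, value_tensor)
attention_output = tf.einsum("BNFA,BAFNH->BFNH", doc_attention_probs,
                             context_layer)
print(attention_output.shape)  # (3, 4, 2, 8): the layer's _output_dense then
                               # projects this [B, F, N, H] tensor to the
                               # output (hidden) size.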
official/nlp/nhnet/multi_channel_attention_test.py View file @ 330b34fe
...
@@ -40,15 +40,15 @@ class MultiChannelAttentionTest(tf.test.TestCase):
     num_heads = 2
     num_docs = 5
     attention_layer = multi_channel_attention.MultiChannelAttention(
-        num_heads, head_size=2)
+        num_heads, key_size=2)
     from_data = 10 * np.random.random_sample((3, 4, 8))
     to_data = 10 * np.random.random_sample((3, num_docs, 2, 8))
     mask_data = np.random.randint(2, size=(3, num_docs, 4, 2))
     doc_probs = np.random.randint(
         2, size=(3, num_heads, 4, num_docs)).astype(float)
-    outputs = attention_layer([from_data, to_data, mask_data, doc_probs])
-    self.assertEqual(outputs.shape, (3, 4, num_heads, 2))
+    outputs = attention_layer([from_data, to_data, doc_probs], mask_data)
+    self.assertEqual(outputs.shape, (3, 4, 8))


 if __name__ == "__main__":
...
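The mask_data tensor in this test is shaped [batch, num_docs, from_seq, to_seq], i.e. the [B, A, F, T] score layout minus the head axis; MaskedSoftmax(mask_expansion_axes=[2]) inserts that missing head axis so the mask broadcasts across heads. A sketch of the broadcast using the standard additive-mask trick (illustrative; not copied from the repo's MaskedSoftmax implementation):

import numpy as np
import tensorflow as tf

batch, num_docs, num_heads, from_seq, to_seq = 3, 5, 2, 4, 2

scores = tf.random.normal((batch, num_docs, num_heads, from_seq, to_seq))
mask = tf.constant(
    np.random.randint(2, size=(batch, num_docs, from_seq, to_seq)),
    dtype=tf.float32)

expanded_mask = tf.expand_dims(mask, axis=2)      # [B, A, 1, F, T]
adder = (1.0 - expanded_mask) * -10000.0          # large negative on masked slots
attention_probs = tf.nn.softmax(scores + adder, axis=-1)
print(attention_probs.shape)                      # (3, 5, 2, 4, 2)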
official/nlp/nhnet/utils.py View file @ 330b34fe
...
@@ -40,7 +40,6 @@ def get_test_params(cls=nhnet_configs.BERT2BERTConfig):
 def encoder_common_layers(transformer_block):
   return [
       transformer_block._attention_layer,
-      transformer_block._attention_output_dense,
       transformer_block._attention_layer_norm,
       transformer_block._intermediate_dense, transformer_block._output_dense,
       transformer_block._output_layer_norm
...