Commit 01c4ee81 in ModelZoo / ResNet50_tensorflow

Authored Apr 16, 2020 by Chen Chen; committed by A. Unique TensorFlower, Apr 16, 2020.

Internal change

PiperOrigin-RevId: 306986791
Parent: dbd74385

Showing 5 changed files with 291 additions and 0 deletions (+291 -0).
official/nlp/bert/run_classifier.py                             +1 -0
official/nlp/bert/run_squad_helper.py                           +1 -0
official/nlp/modeling/layers/__init__.py                        +1 -0
official/nlp/modeling/layers/talking_heads_attention.py         +193 -0
official/nlp/modeling/layers/talking_heads_attention_test.py    +95 -0
official/nlp/bert/run_classifier.py

@@ -56,6 +56,7 @@ flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
 flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')
 common_flags.define_common_bert_flags()
+common_flags.define_gin_flags()
 FLAGS = flags.FLAGS
official/nlp/bert/run_squad_helper.py

@@ -88,6 +88,7 @@ def define_common_squad_flags():
       'another.')
   common_flags.define_common_bert_flags()
+  common_flags.define_gin_flags()
 FLAGS = flags.FLAGS
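Both BERT entry points now call common_flags.define_gin_flags() right after define_common_bert_flags(). The helper's body is not part of this diff; purely as a hedged sketch of what such a gin-flag helper typically looks like, using the real absl.flags and gin-config APIs (the flag names below are assumptions, not the repository's actual definitions):

# Hypothetical sketch only: the body of common_flags.define_gin_flags() is not
# shown in this commit.
from absl import flags
import gin

FLAGS = flags.FLAGS


def define_gin_flags():
  # Flag names are assumptions; the real helper may use different ones.
  flags.DEFINE_multi_string(
      "gin_file", None, "List of paths to gin configuration files.")
  flags.DEFINE_multi_string(
      "gin_param", None,
      "Gin parameter bindings, e.g. TalkingHeadsAttention.dropout_rate = 0.1.")


def parse_gin_flags():
  # Standard gin-config entry point for applying files and bindings.
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

Since the new layer is decorated with @gin.configurable (see below), such flags let its constructor arguments be overridden from the command line without code changes.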
official/nlp/modeling/layers/__init__.py

@@ -20,5 +20,6 @@ from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
 from official.nlp.modeling.layers.position_embedding import PositionEmbedding
 from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer
 from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
+from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention
 from official.nlp.modeling.layers.transformer import Transformer
 from official.nlp.modeling.layers.transformer_scaffold import TransformerScaffold
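With this export, the layer is reachable through the package-level namespace. A minimal usage sketch, mirroring the shapes used in the unit tests added below (the variable names here are illustrative only):

import tensorflow as tf
from official.nlp.modeling import layers

# Construct the newly exported layer through the layers package.
attention = layers.TalkingHeadsAttention(num_heads=12, head_size=64)

# Self-attention over a [batch, seq_len=40, width=80] input; the output has
# shape [batch, 40, 12, 64] (per-head context vectors).
inputs = tf.keras.Input(shape=(40, 80))
outputs = attention([inputs, inputs])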
official/nlp/modeling/layers/talking_heads_attention.py  (new file, mode 100644)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Talking Head Attention layer."""
# pylint: disable=g-classes-have-attributes
import math

import gin
import tensorflow as tf

from official.nlp.modeling.layers import dense_einsum
from official.nlp.modeling.layers import masked_softmax
@tf.keras.utils.register_keras_serializable(package="Text")
@gin.configurable
class TalkingHeadsAttention(tf.keras.layers.Layer):
  """Implements Talking-Heads Attention.

  https://arxiv.org/abs/2003.02436

  Arguments:
    num_heads: Number of attention heads.
    head_size: Size of each attention head.
    dropout_rate: Dropout probability.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
    bias_regularizer: Regularizer for dense layer biases.
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
  """
  def __init__(self,
               num_heads,
               head_size,
               dropout_rate=0.0,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    super(TalkingHeadsAttention, self).__init__(**kwargs)
    self._num_heads = num_heads
    self._head_size = head_size
    self._dropout_rate = dropout_rate
    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf.keras.constraints.get(bias_constraint)

    self._query_dense = dense_einsum.DenseEinsum(
        output_shape=(self._num_heads, self._head_size),
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="query")

    self._key_dense = dense_einsum.DenseEinsum(
        output_shape=(self._num_heads, self._head_size),
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="key")

    self._value_dense = dense_einsum.DenseEinsum(
        output_shape=(self._num_heads, self._head_size),
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="value")

    self._masked_softmax = masked_softmax.MaskedSoftmax(
        mask_expansion_axes=[1])
    self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
  def build(self, input_shape):
    super(TalkingHeadsAttention, self).build(input_shape)
    self._pre_softmax_weight = self.add_weight(
        "pre_softmax_weight",
        shape=(self._num_heads, self._num_heads),
        initializer=self._kernel_initializer,
        regularizer=self._kernel_regularizer,
        constraint=self._kernel_constraint,
        dtype=self.dtype,
        trainable=True)
    self._post_softmax_weight = self.add_weight(
        "post_softmax_weight",
        shape=(self._num_heads, self._num_heads),
        initializer=self._kernel_initializer,
        regularizer=self._kernel_regularizer,
        constraint=self._kernel_constraint,
        dtype=self.dtype,
        trainable=True)
  def get_config(self):
    config = {
        "num_heads": self._num_heads,
        "head_size": self._head_size,
        "dropout_rate": self._dropout_rate,
        "kernel_initializer":
            tf.keras.initializers.serialize(self._kernel_initializer),
        "bias_initializer":
            tf.keras.initializers.serialize(self._bias_initializer),
        "kernel_regularizer":
            tf.keras.regularizers.serialize(self._kernel_regularizer),
        "bias_regularizer":
            tf.keras.regularizers.serialize(self._bias_regularizer),
        "activity_regularizer":
            tf.keras.regularizers.serialize(self._activity_regularizer),
        "kernel_constraint":
            tf.keras.constraints.serialize(self._kernel_constraint),
        "bias_constraint":
            tf.keras.constraints.serialize(self._bias_constraint)
    }
    base_config = super(TalkingHeadsAttention, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
  def call(self, inputs):
    from_tensor = inputs[0]
    to_tensor = inputs[1]
    attention_mask = inputs[2] if len(inputs) == 3 else None

    # Scalar dimensions referenced here:
    #   B = batch size (number of sequences)
    #   F = `from_tensor` sequence length
    #   T = `to_tensor` sequence length
    #   N = L = `num_attention_heads`
    #   H = `size_per_head`
    # `query_tensor` = [B, F, N, H]
    query_tensor = self._query_dense(from_tensor)

    # `key_tensor` = [B, T, N, H]
    key_tensor = self._key_dense(to_tensor)

    # `value_tensor` = [B, T, N, H]
    value_tensor = self._value_dense(to_tensor)

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(self._head_size)))

    # Apply talking heads before softmax.
    attention_scores = tf.einsum("BNFT,NL->BLFT", attention_scores,
                                 self._pre_softmax_weight)

    # Normalize the attention scores to probabilities.
    # `attention_probs` = [B, N, F, T]
    attention_probs = self._masked_softmax([attention_scores, attention_mask])

    # Apply talking heads after softmax.
    attention_probs = tf.einsum("BNFT,NL->BLFT", attention_probs,
                                self._post_softmax_weight)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = self._dropout(attention_probs)

    # `context_layer` = [B, F, N, H]
    return tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor)
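The two extra einsum contractions are what distinguish this layer from standard multi-head attention: the pre- and post-softmax weights of shape [N, N] linearly mix the attention logits and probabilities across heads, per the paper linked in the docstring. A small standalone sketch of that head-mixing contraction (shapes and names here are illustrative, not part of the commit):

import tensorflow as tf

# Illustrative only: mix attention scores across heads with an [N, L] matrix,
# the same contraction pattern used with _pre_softmax_weight above.
batch, num_heads, from_len, to_len = 2, 4, 5, 7
scores = tf.random.normal([batch, num_heads, from_len, to_len])  # [B, N, F, T]
head_mix = tf.random.normal([num_heads, num_heads])              # [N, L]

mixed = tf.einsum("BNFT,NL->BLFT", scores, head_mix)

# Each output head l is a weighted combination of all input heads n:
#   mixed[b, l, f, t] = sum_n scores[b, n, f, t] * head_mix[n, l]
print(mixed.shape)  # (2, 4, 5, 7)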
official/nlp/modeling/layers/talking_heads_attention_test.py  (new file, mode 100644)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the attention layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.layers import talking_heads_attention


# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class MultiHeadAttentionTest(keras_parameterized.TestCase):

  def test_non_masked_attention(self):
    """Test that the attention layer can be created without a mask tensor."""
    test_layer = talking_heads_attention.TalkingHeadsAttention(
        num_heads=12, head_size=64)
    # Create a 3-dimensional input (the first dimension is implicit).
    from_tensor = tf.keras.Input(shape=(40, 80))
    to_tensor = tf.keras.Input(shape=(20, 80))
    output = test_layer([from_tensor, to_tensor])
    self.assertEqual(output.shape.as_list(), [None, 40, 12, 64])
  def test_non_masked_self_attention(self):
    """Test with one input (self-attention) and no mask tensor."""
    test_layer = talking_heads_attention.TalkingHeadsAttention(
        num_heads=12, head_size=64)
    # Create a 3-dimensional input (the first dimension is implicit).
    from_tensor = tf.keras.Input(shape=(40, 80))
    output = test_layer([from_tensor, from_tensor])
    self.assertEqual(output.shape.as_list(), [None, 40, 12, 64])
  def test_masked_attention(self):
    """Test with a mask tensor."""
    test_layer = talking_heads_attention.TalkingHeadsAttention(
        num_heads=2, head_size=2)
    # Create a 3-dimensional input (the first dimension is implicit).
    from_tensor = tf.keras.Input(shape=(4, 8))
    to_tensor = tf.keras.Input(shape=(2, 8))
    mask_tensor = tf.keras.Input(shape=(4, 2))
    output = test_layer([from_tensor, to_tensor, mask_tensor])

    # Create a model containing the test layer.
    model = tf.keras.Model([from_tensor, to_tensor, mask_tensor], output)

    # Generate data for the input (non-mask) tensors.
    from_data = 10 * np.random.random_sample((3, 4, 8))
    to_data = 10 * np.random.random_sample((3, 2, 8))

    # Invoke the data with a random set of mask data. This should mask at least
    # one element.
    mask_data = np.random.randint(2, size=(3, 4, 2))
    masked_output_data = model.predict([from_data, to_data, mask_data])

    # Invoke the same data, but with a null mask (where no elements are masked).
    null_mask_data = np.ones((3, 4, 2))
    unmasked_output_data = model.predict([from_data, to_data, null_mask_data])

    # Because one data is masked and one is not, the outputs should not be the
    # same.
    self.assertNotAllClose(masked_output_data, unmasked_output_data)

  def test_initializer(self):
    """Test with a specified initializer."""
    test_layer = talking_heads_attention.TalkingHeadsAttention(
        num_heads=12,
        head_size=64,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
    # Create a 3-dimensional input (the first dimension is implicit).
    from_tensor = tf.keras.Input(shape=(40, 80))
    output = test_layer([from_tensor, from_tensor])
    self.assertEqual(output.shape.as_list(), [None, 40, 12, 64])


if __name__ == "__main__":
  tf.test.main()