ModelZoo / ResNet50_tensorflow / Commits / 43f5340f

Commit 43f5340f
Authored Jul 08, 2020 by Hongkun Yu; committed by A. Unique TensorFlower on Jul 08, 2020

Migrate all DenseEinsum to tf.keras.experimental.EinsumDense

PiperOrigin-RevId: 320240466
Parent: 180f2607

Showing 5 changed files with 88 additions and 127 deletions (+88 -127):
  official/nlp/modeling/layers/README.md                     +0   -5
  official/nlp/modeling/layers/dense_einsum.py                +5   -0
  official/nlp/modeling/layers/multi_channel_attention.py    +14  -17
  official/nlp/modeling/layers/rezero_transformer.py         +20  -29
  official/nlp/modeling/layers/transformer.py                +49  -76
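
Editor's note: the commit message above describes a mechanical API migration. Every construction of the repo-local DenseEinsum layer is replaced with the stock Keras layer tf.keras.layers.experimental.EinsumDense (the TF 2.3-era location of what the message calls tf.keras.experimental.EinsumDense), with the einsum contraction spelled out explicitly. The snippet below is an illustrative before/after sketch, not code from this commit; the argument names mirror the hunks that follow, and hidden_size is a made-up example value.

    import tensorflow as tf

    # Before (repo-local layer, now deprecated):
    #   from official.nlp.modeling.layers import dense_einsum
    #   output_dense = dense_einsum.DenseEinsum(output_shape=hidden_size, name="output")
    # After (stock Keras layer; the contraction is written out as an einsum equation):
    hidden_size = 768  # example value, not taken from the commit
    output_dense = tf.keras.layers.experimental.EinsumDense(
        "abc,cd->abd",                     # [batch, seq, in] x [in, out] -> [batch, seq, out]
        output_shape=(None, hidden_size),  # None keeps the sequence axis dynamic
        bias_axes="d",                     # one bias term per output unit
        name="output")

    x = tf.random.normal([2, 16, 512])     # [batch, seq, in]
    print(output_dense(x).shape)           # (2, 16, 768)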
official/nlp/modeling/layers/README.md
@@ -3,11 +3,6 @@
 Layers are the fundamental building blocks for NLP models. They can be used to
 assemble new layers, networks, or models.
 
-*   [DenseEinsum](dense_einsum.py) implements a feedforward network using
-    tf.einsum. This layer contains the einsum op, the associated weight, and the
-    logic required to generate the einsum expression for the given
-    initialization parameters.
-
 *   [MultiHeadAttention](attention.py) implements an optionally masked attention
     between query, key, value tensors as described in
     ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762). If
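
Editor's note: for context on the bullet removed above, DenseEinsum derives an einsum expression from the input rank and the requested output shape, then contracts the input against a single weight tensor. The sketch below illustrates that idea with plain tf.einsum; it is not the DenseEinsum implementation, and the helper name build_einsum_equation is invented for this example.

    import string

    import tensorflow as tf

    def build_einsum_equation(input_rank, output_rank, num_summed_dims=1):
      """Hypothetical helper: derive an einsum equation the way the README text describes."""
      letters = string.ascii_lowercase
      free = input_rank - num_summed_dims          # leading dims passed through unchanged
      in_spec = letters[:input_rank]               # e.g. "abc"
      new_out = letters[input_rank:input_rank + output_rank]
      kernel_spec = in_spec[free:] + new_out       # summed input dims + new output dims
      return f"{in_spec},{kernel_spec}->{in_spec[:free] + new_out}"

    # [batch, seq, hidden] x [hidden, units] -> [batch, seq, units]
    equation = build_einsum_equation(input_rank=3, output_rank=1)   # "abc,cd->abd"
    x = tf.random.normal([2, 16, 64])
    kernel = tf.random.normal([64, 32])
    print(equation, tf.einsum(equation, x, kernel).shape)           # abc,cd->abd (2, 16, 32)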
official/nlp/modeling/layers/dense_einsum.py
@@ -21,6 +21,8 @@ from __future__ import print_function
 
 import tensorflow as tf
 
+from tensorflow.python.util import deprecation
+
 _CHR_IDX = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m"]
 
 
@@ -57,6 +59,9 @@ class DenseEinsum(tf.keras.layers.Layer):
       `(batch_size, units)`.
   """
 
+  @deprecation.deprecated(
+      None, "DenseEinsum is deprecated. Please use "
+      "tf.keras.experimental.EinsumDense layer instead.")
   def __init__(self,
                output_shape,
                num_summed_dimensions=1,
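
Editor's note: the hunk above keeps DenseEinsum available but wraps its constructor with TensorFlow's deprecation decorator, so instantiating the layer logs a one-time warning pointing at EinsumDense. A minimal sketch of the same pattern follows; note that tensorflow.python.util.deprecation is a private TensorFlow module, and OldLayer is a made-up class used only for illustration.

    import tensorflow as tf
    from tensorflow.python.util import deprecation  # private API, as used in the hunk above


    class OldLayer(tf.keras.layers.Layer):
      """Hypothetical layer, shown only to illustrate the deprecation pattern."""

      @deprecation.deprecated(
          None,  # no removal date, instructions only
          "OldLayer is deprecated. Please use "
          "tf.keras.layers.experimental.EinsumDense instead.")
      def __init__(self, units, **kwargs):
        super(OldLayer, self).__init__(**kwargs)
        self.units = units


    layer = OldLayer(8)  # logs a "... is deprecated" warning once, at construction time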
official/nlp/modeling/layers/multi_channel_attention.py
@@ -26,7 +26,6 @@ import math
 
 import tensorflow as tf
 
 from official.modeling import tf_utils
 from official.nlp.modeling.layers import attention
-from official.nlp.modeling.layers import dense_einsum
 from official.nlp.modeling.layers import masked_softmax
@@ -67,28 +66,26 @@ class VotingAttention(tf.keras.layers.Layer):
     self._bias_constraint = tf.keras.constraints.get(bias_constraint)
 
   def build(self, unused_input_shapes):
-    self._query_dense = dense_einsum.DenseEinsum(
-        output_shape=(self._num_heads, self._head_size),
+    common_kwargs = dict(
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        dtype=self.dtype,
-        name="encdocatt_query")
-    self._key_dense = dense_einsum.DenseEinsum(
-        output_shape=(self._num_heads, self._head_size),
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        dtype=self.dtype,
-        name="encdocatt_key")
+        bias_constraint=self._bias_constraint)
+    self._query_dense = tf.keras.layers.experimental.EinsumDense(
+        "BAE,ENH->BANH",
+        output_shape=(None, self._num_heads, self._head_size),
+        bias_axes="NH",
+        name="query",
+        **common_kwargs)
+    self._key_dense = tf.keras.layers.experimental.EinsumDense(
+        "BAE,ENH->BANH",
+        output_shape=(None, self._num_heads, self._head_size),
+        bias_axes="NH",
+        name="key",
+        **common_kwargs)
     super(VotingAttention, self).build(unused_input_shapes)
 
   def call(self, encoder_outputs, doc_attention_mask):
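
Editor's note: the replacement query/key projections in VotingAttention use the equation "BAE,ENH->BANH": batch B and sequence A pass through unchanged, the embedding axis E is contracted against an [E, N, H] kernel, and the output gains head axes N (number of heads) and H (head size), with a bias over "NH". The standalone shape check below uses made-up sizes; only the layer construction mirrors the hunk above.

    import tensorflow as tf

    num_heads, head_size, embed_dim = 4, 16, 64   # example sizes, not from the commit

    query_dense = tf.keras.layers.experimental.EinsumDense(
        "BAE,ENH->BANH",
        output_shape=(None, num_heads, head_size),  # sequence length stays dynamic
        bias_axes="NH",
        name="query")

    encoder_outputs = tf.random.normal([2, 10, embed_dim])   # [batch, seq, embedding]
    print(query_dense(encoder_outputs).shape)                # (2, 10, 4, 16)
    # Kernel is [embed_dim, heads, head_size]; bias covers the "NH" axes.
    print([w.shape.as_list() for w in query_dense.weights])  # [[64, 4, 16], [4, 16]]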
official/nlp/modeling/layers/rezero_transformer.py
@@ -23,7 +23,6 @@ import gin
 import tensorflow as tf
 
 from official.nlp.modeling.layers import attention
-from official.nlp.modeling.layers import dense_einsum
 
 
 @tf.keras.utils.register_keras_serializable(package="Text")
@@ -109,19 +108,20 @@ class ReZeroTransformer(tf.keras.layers.Layer):
           "The input size (%d) is not a multiple of the number of attention "
           "heads (%d)" % (hidden_size, self._num_heads))
     self._attention_head_size = int(hidden_size // self._num_heads)
-    self._attention_layer = attention.MultiHeadAttention(
-        num_heads=self._num_heads,
-        key_size=self._attention_head_size,
-        dropout=self._attention_dropout_rate,
+    common_kwargs = dict(
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="self_attention")
+        bias_constraint=self._bias_constraint)
+    self._attention_layer = attention.MultiHeadAttention(
+        num_heads=self._num_heads,
+        key_size=self._attention_head_size,
+        dropout=self._attention_dropout_rate,
+        name="self_attention",
+        **common_kwargs)
     self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     if self._use_layer_norm:
       # Use float32 in layernorm for numeric stability.
@@ -132,17 +132,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
             axis=-1,
             epsilon=1e-12,
             dtype=tf.float32))
-    self._intermediate_dense = dense_einsum.DenseEinsum(
-        output_shape=self._intermediate_size,
-        activation=None,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="intermediate")
+    self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, self._intermediate_size),
+        bias_axes="d",
+        name="intermediate",
+        **common_kwargs)
     policy = tf.keras.mixed_precision.experimental.global_policy()
     if policy.name == "mixed_bfloat16":
       # bfloat16 causes BERT with the LAMB optimizer to not converge
@@ -151,16 +146,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
       policy = tf.float32
     self._intermediate_activation_layer = tf.keras.layers.Activation(
         self._intermediate_activation, dtype=policy)
-    self._output_dense = dense_einsum.DenseEinsum(
-        output_shape=hidden_size,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="output")
+    self._output_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, hidden_size),
+        bias_axes="d",
+        name="output",
+        **common_kwargs)
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     if self._use_layer_norm:
       # Use float32 in layernorm for numeric stability.
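
Editor's note: two patterns recur in the ReZeroTransformer hunks above. The initializer/regularizer/constraint settings are collected once in a common_kwargs dict and splatted into every sublayer, and each former DenseEinsum feed-forward projection becomes EinsumDense("abc,cd->abd", ...), i.e. a position-wise dense layer. The sketch below uses illustrative sizes and checks that this EinsumDense form carries the same weight shapes as an ordinary Dense layer applied to every token.

    import tensorflow as tf

    hidden_size, intermediate_size = 64, 256   # example sizes, not from the commit

    common_kwargs = dict(                      # shared settings, as in build() above
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros")

    intermediate_dense = tf.keras.layers.experimental.EinsumDense(
        "abc,cd->abd",
        output_shape=(None, intermediate_size),
        bias_axes="d",
        name="intermediate",
        **common_kwargs)
    reference_dense = tf.keras.layers.Dense(intermediate_size, **common_kwargs)

    x = tf.random.normal([2, 8, hidden_size])  # [batch, seq, hidden]
    print(intermediate_dense(x).shape, reference_dense(x).shape)    # (2, 8, 256) (2, 8, 256)
    # Same parameterization: a [hidden, intermediate] kernel plus an [intermediate] bias.
    print([w.shape.as_list() for w in intermediate_dense.weights])  # [[64, 256], [256]]
    print([w.shape.as_list() for w in reference_dense.weights])     # [[64, 256], [256]]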
official/nlp/modeling/layers/transformer.py
@@ -23,7 +23,6 @@ import gin
 import tensorflow as tf
 
 from official.nlp.modeling.layers import attention
-from official.nlp.modeling.layers import dense_einsum
 from official.nlp.modeling.layers import multi_channel_attention
 from official.nlp.modeling.layers.util import tf_function_if_eager
 
@@ -106,19 +105,20 @@ class Transformer(tf.keras.layers.Layer):
           "The input size (%d) is not a multiple of the number of attention "
           "heads (%d)" % (hidden_size, self._num_heads))
     self._attention_head_size = int(hidden_size // self._num_heads)
-    self._attention_layer = attention.MultiHeadAttention(
-        num_heads=self._num_heads,
-        key_size=self._attention_head_size,
-        dropout=self._attention_dropout_rate,
+    common_kwargs = dict(
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="self_attention")
+        bias_constraint=self._bias_constraint)
+    self._attention_layer = attention.MultiHeadAttention(
+        num_heads=self._num_heads,
+        key_size=self._attention_head_size,
+        dropout=self._attention_dropout_rate,
+        name="self_attention",
+        **common_kwargs)
     # pylint: disable=protected-access
     self._attention_layer.build([input_tensor_shape] * 3)
     self._attention_output_dense = self._attention_layer._output_dense
@@ -132,17 +132,12 @@ class Transformer(tf.keras.layers.Layer):
             axis=-1,
             epsilon=1e-12,
             dtype=tf.float32))
-    self._intermediate_dense = dense_einsum.DenseEinsum(
-        output_shape=self._intermediate_size,
-        activation=None,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="intermediate")
+    self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, self._intermediate_size),
+        bias_axes="d",
+        name="intermediate",
+        **common_kwargs)
     policy = tf.keras.mixed_precision.experimental.global_policy()
     if policy.name == "mixed_bfloat16":
       # bfloat16 causes BERT with the LAMB optimizer to not converge
@@ -151,16 +146,12 @@ class Transformer(tf.keras.layers.Layer):
       policy = tf.float32
     self._intermediate_activation_layer = tf.keras.layers.Activation(
         self._intermediate_activation, dtype=policy)
-    self._output_dense = dense_einsum.DenseEinsum(
-        output_shape=hidden_size,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="output")
+    self._output_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, hidden_size),
+        bias_axes="d",
+        name="output",
+        **common_kwargs)
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     # Use float32 in layernorm for numeric stability.
     self._output_layer_norm = tf.keras.layers.LayerNormalization(
@@ -312,30 +303,27 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
           "The hidden size (%d) is not a multiple of the number of attention "
           "heads (%d)" % (hidden_size, self.num_attention_heads))
     self.attention_head_size = int(hidden_size / self.num_attention_heads)
-    # Self attention.
-    self.self_attention = attention.CachedAttention(
-        num_heads=self.num_attention_heads,
-        key_size=self.attention_head_size,
-        dropout=self.attention_dropout_rate,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="self_attention")
-    self.self_attention_output_dense = dense_einsum.DenseEinsum(
-        output_shape=hidden_size,
-        num_summed_dimensions=2,
+    common_kwargs = dict(
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="self_attention_output")
+        bias_constraint=self._bias_constraint)
+    # Self attention.
+    self.self_attention = attention.CachedAttention(
+        num_heads=self.num_attention_heads,
+        key_size=self.attention_head_size,
+        dropout=self.attention_dropout_rate,
+        name="self_attention",
+        **common_kwargs)
+    self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, hidden_size),
+        bias_axes="d",
+        name="output",
+        **common_kwargs)
     self.self_attention_dropout = tf.keras.layers.Dropout(
         rate=self.dropout_rate)
     self.self_attention_layer_norm = (
@@ -347,14 +335,8 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         key_size=self.attention_head_size,
         dropout=self.attention_dropout_rate,
         output_shape=hidden_size,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="attention/encdec")
+        name="attention/encdec",
+        **common_kwargs)
     self.encdec_attention_dropout = tf.keras.layers.Dropout(
         rate=self.dropout_rate)
@@ -363,29 +345,20 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
             name="attention/encdec_output_layer_norm",
             axis=-1,
             epsilon=1e-12))
     # Feed-forward projection.
-    self.intermediate_dense = dense_einsum.DenseEinsum(
-        output_shape=self.intermediate_size,
-        activation=None,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="intermediate")
+    self.intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, self.intermediate_size),
+        bias_axes="d",
+        name="intermediate",
+        **common_kwargs)
     self.intermediate_activation_layer = tf.keras.layers.Activation(
         self.intermediate_activation)
-    self.output_dense = dense_einsum.DenseEinsum(
-        output_shape=hidden_size,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="output")
+    self.output_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, hidden_size),
+        bias_axes="d",
+        name="output",
+        **common_kwargs)
     self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
     self.output_layer_norm = tf.keras.layers.LayerNormalization(
         name="output_layer_norm", axis=-1, epsilon=1e-12)
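
Editor's note: the Transformer and TransformerDecoderLayer hunks follow the same refactor. One common_kwargs dict carries the shared initializer/regularizer/constraint settings, and every attention and EinsumDense sublayer receives it via **common_kwargs, so each call only spells out what is layer-specific (equation, output shape, dropout, name). A generic sketch of the pattern with stock Keras layers follows; the settings and sizes are illustrative, and the get_config lookup is only there to show that both sublayers picked up the shared configuration.

    import tensorflow as tf

    # Shared settings, mirroring the common_kwargs dict built in build() above.
    common_kwargs = dict(
        kernel_initializer="truncated_normal",
        kernel_regularizer=tf.keras.regularizers.l2(1e-4),
        bias_initializer="zeros")

    hidden_size, intermediate_size = 64, 256   # illustrative sizes

    intermediate_dense = tf.keras.layers.experimental.EinsumDense(
        "abc,cd->abd", output_shape=(None, intermediate_size),
        bias_axes="d", name="intermediate", **common_kwargs)
    output_dense = tf.keras.layers.experimental.EinsumDense(
        "abc,cd->abd", output_shape=(None, hidden_size),
        bias_axes="d", name="output", **common_kwargs)

    # Both sublayers end up with identical initializer/regularizer configuration.
    for layer in (intermediate_dense, output_dense):
      config = layer.get_config()
      print(layer.name, config["kernel_initializer"], config["kernel_regularizer"])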