ModelZoo / ResNet50_tensorflow, commit 8aa44501

Authored Jun 21, 2020 by Hongkun Yu; committed by A. Unique TensorFlower on Jun 21, 2020.

Internal change

PiperOrigin-RevId: 317596394
parent 4b0cec67
Showing 5 changed files with 95 additions and 37 deletions.
official/nlp/modeling/layers/README.md             +6   -2
official/nlp/modeling/layers/__init__.py           +1   -1
official/nlp/modeling/layers/transformer.py        +82  -26
official/nlp/modeling/layers/transformer_test.py   +2   -4
official/nlp/nhnet/decoder.py                      +4   -4
official/nlp/modeling/layers/README.md

@@ -28,6 +28,10 @@ assemble new layers, networks, or models.
    described in
    ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762).

+*  [TransformerDecoderLayer](transformer.py) TransformerDecoderLayer is made up
+   of self multi-head attention, cross multi-head attention and
+   feedforward network.
+
 *  [ReZeroTransformer](rezero_transformer.py) implements Transformer with
    ReZero described in
    ["ReZero is All You Need: Fast Convergence at Large Depth"](https://arxiv.org/abs/2003.04887).

@@ -49,8 +53,8 @@ assemble new layers, networks, or models.
    should be masked), the output will have masked positions set to
    approximately zero.

-*  [`MaskedLM`](masked_lm.py) implements a masked language model. It assumes the
-   embedding table variable is passed to it.
+*  [`MaskedLM`](masked_lm.py) implements a masked language model. It assumes
+   the embedding table variable is passed to it.

 *  [ClassificationHead](cls_head.py) A pooling head over a sequence of
    embeddings, commonly used by classification tasks.
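The newly documented TransformerDecoderLayer combines self-attention, cross-attention, and a feed-forward network. A minimal construction sketch follows; the keyword arguments are taken from the transformer.py change later in this commit, while the import path is assumed from the repository layout and is not part of the diff:

from official.nlp.modeling.layers import transformer  # assumed import path

# Constructor arguments as introduced by this commit; the values are illustrative.
decoder_layer = transformer.TransformerDecoderLayer(
    num_attention_heads=8,
    intermediate_size=2048,
    intermediate_activation="relu",
    dropout_rate=0.1,
    attention_dropout_rate=0.1)
# Note: the hidden size is no longer passed here; the layer infers it from the
# shape of its input when it is built.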
official/nlp/modeling/layers/__init__.py

@@ -26,5 +26,5 @@ from official.nlp.modeling.layers.position_embedding import PositionEmbedding
 from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer
 from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
 from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention
-from official.nlp.modeling.layers.transformer import Transformer
+from official.nlp.modeling.layers.transformer import *
 from official.nlp.modeling.layers.transformer_scaffold import TransformerScaffold
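Replacing the single re-export with a wildcard import exposes every public symbol defined in transformer.py, including TransformerDecoderLayer, through the layers package. A small sketch of the resulting import surface, assuming the usual package __init__ chain (not shown in this diff):

from official.nlp.modeling import layers  # assumed package-level import

encoder_cls = layers.Transformer               # re-exported before and after this change
decoder_cls = layers.TransformerDecoderLayer   # now reachable thanks to `import *`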
official/nlp/modeling/layers/transformer.py

@@ -79,6 +79,7 @@ class Transformer(tf.keras.layers.Layer):
     self._bias_initializer = tf.keras.initializers.get(bias_initializer)
     self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
     self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
+    self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
     self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
     self._bias_constraint = tf.keras.constraints.get(bias_constraint)
@@ -247,57 +248,96 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
     (1) a multi-head self-attention mechanism.
     (2) a encoder-decoder attention.
     (3) a positionwise fully connected feed-forward network.

+  Arguments:
+    num_attention_heads: Number of attention heads.
+    intermediate_size: Size of the intermediate layer.
+    intermediate_activation: Activation for the intermediate layer.
+    dropout_rate: Dropout probability for the post-attention and output dropout.
+    attention_dropout_rate: Dropout probability for within the attention layer.
+    multi_channel_cross_attention: Whether to use `MultiChannelAttention` for
+      cross-attention between target sequences and source sequences.
+    kernel_initializer: Initializer for dense layer kernels.
+    bias_initializer: Initializer for dense layer biases.
+    kernel_regularizer: Regularizer for dense layer kernels.
+    bias_regularizer: Regularizer for dense layer biases.
+    activity_regularizer: Regularizer for dense layer activity.
+    kernel_constraint: Constraint for dense layer kernels.
+    bias_constraint: Constraint for dense layer kernels.
   """

   def __init__(self,
-               hidden_size=768,
-               num_attention_heads=12,
-               intermediate_size=3072,
-               intermediate_activation="relu",
-               hidden_dropout_prob=0.0,
-               attention_probs_dropout_prob=0.0,
-               initializer_range=0.02,
+               num_attention_heads,
+               intermediate_size,
+               intermediate_activation,
+               dropout_rate=0.0,
+               attention_dropout_rate=0.0,
                multi_channel_cross_attention=False,
+               kernel_initializer="glorot_uniform",
+               bias_initializer="zeros",
+               kernel_regularizer=None,
+               bias_regularizer=None,
+               activity_regularizer=None,
+               kernel_constraint=None,
+               bias_constraint=None,
                **kwargs):
     super(TransformerDecoderLayer, self).__init__(**kwargs)
-    self.hidden_size = hidden_size
     self.num_attention_heads = num_attention_heads
     self.intermediate_size = intermediate_size
     self.intermediate_activation = tf.keras.activations.get(
         intermediate_activation)
-    self.hidden_dropout_prob = hidden_dropout_prob
-    self.attention_probs_dropout_prob = attention_probs_dropout_prob
+    self.dropout_rate = dropout_rate
+    self.attention_dropout_rate = attention_dropout_rate
     self.multi_channel_cross_attention = multi_channel_cross_attention
-    self._kernel_initializer = tf.keras.initializers.TruncatedNormal(
-        stddev=initializer_range)
-    self._bias_initializer = tf.keras.initializers.get("zeros")
+    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
+    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
+    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
+    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
+    self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
+    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
+    self._bias_constraint = tf.keras.constraints.get(bias_constraint)
     if self.multi_channel_cross_attention:
       self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
     else:
       self._cross_attention_cls = attention.MultiHeadAttention
-    if self.hidden_size % self.num_attention_heads != 0:
-      raise ValueError(
-          "The hidden size (%d) is not a multiple of the number of attention "
-          "heads (%d)" % (self.hidden_size, self.num_attention_heads))
-    self.attention_head_size = int(self.hidden_size / self.num_attention_heads)

   def build(self, input_shape):
+    target_tensor_shape = tf.TensorShape(input_shape[0])
+    if len(target_tensor_shape) != 3:
+      raise ValueError("TransformerLayer expects a three-dimensional input of "
+                       "shape [batch, sequence, width].")
+    hidden_size = target_tensor_shape[2]
+    if hidden_size % self.num_attention_heads != 0:
+      raise ValueError(
+          "The hidden size (%d) is not a multiple of the number of attention "
+          "heads (%d)" % (hidden_size, self.num_attention_heads))
+    self.attention_head_size = int(hidden_size / self.num_attention_heads)
     # Self attention.
     self.self_attention = attention.CachedAttention(
         num_heads=self.num_attention_heads,
         key_size=self.attention_head_size,
-        dropout=self.attention_probs_dropout_prob,
+        dropout=self.attention_dropout_rate,
         kernel_initializer=self._kernel_initializer,
+        bias_initializer=self._bias_initializer,
+        kernel_regularizer=self._kernel_regularizer,
+        bias_regularizer=self._bias_regularizer,
+        activity_regularizer=self._activity_regularizer,
+        kernel_constraint=self._kernel_constraint,
+        bias_constraint=self._bias_constraint,
         name="self_attention")
     self.self_attention_output_dense = dense_einsum.DenseEinsum(
-        output_shape=self.hidden_size,
+        output_shape=hidden_size,
         num_summed_dimensions=2,
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
+        kernel_regularizer=self._kernel_regularizer,
+        bias_regularizer=self._bias_regularizer,
+        activity_regularizer=self._activity_regularizer,
+        kernel_constraint=self._kernel_constraint,
+        bias_constraint=self._bias_constraint,
         name="self_attention_output")
     self.self_attention_dropout = tf.keras.layers.Dropout(
-        rate=self.hidden_dropout_prob)
+        rate=self.dropout_rate)
     self.self_attention_layer_norm = (
         tf.keras.layers.LayerNormalization(
             name="self_attention_layer_norm", axis=-1, epsilon=1e-12))
@@ -305,13 +345,19 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
     self.encdec_attention = self._cross_attention_cls(
         num_heads=self.num_attention_heads,
         key_size=self.attention_head_size,
-        dropout=self.attention_probs_dropout_prob,
-        output_shape=self.hidden_size,
+        dropout=self.attention_dropout_rate,
+        output_shape=hidden_size,
         kernel_initializer=self._kernel_initializer,
+        bias_initializer=self._bias_initializer,
+        kernel_regularizer=self._kernel_regularizer,
+        bias_regularizer=self._bias_regularizer,
+        activity_regularizer=self._activity_regularizer,
+        kernel_constraint=self._kernel_constraint,
+        bias_constraint=self._bias_constraint,
         name="attention/encdec")
     self.encdec_attention_dropout = tf.keras.layers.Dropout(
-        rate=self.hidden_dropout_prob)
+        rate=self.dropout_rate)
     self.encdec_attention_layer_norm = (
         tf.keras.layers.LayerNormalization(
             name="attention/encdec_output_layer_norm", axis=-1, epsilon=1e-12))
@@ -322,15 +368,25 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         activation=None,
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
+        kernel_regularizer=self._kernel_regularizer,
+        bias_regularizer=self._bias_regularizer,
+        activity_regularizer=self._activity_regularizer,
+        kernel_constraint=self._kernel_constraint,
+        bias_constraint=self._bias_constraint,
         name="intermediate")
     self.intermediate_activation_layer = tf.keras.layers.Activation(
         self.intermediate_activation)
     self.output_dense = dense_einsum.DenseEinsum(
-        output_shape=self.hidden_size,
+        output_shape=hidden_size,
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
+        kernel_regularizer=self._kernel_regularizer,
+        bias_regularizer=self._bias_regularizer,
+        activity_regularizer=self._activity_regularizer,
+        kernel_constraint=self._kernel_constraint,
+        bias_constraint=self._bias_constraint,
         name="output")
-    self.output_dropout = tf.keras.layers.Dropout(rate=self.hidden_dropout_prob)
+    self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
     self.output_layer_norm = tf.keras.layers.LayerNormalization(
         name="output_layer_norm", axis=-1, epsilon=1e-12)
     super(TransformerDecoderLayer, self).build(input_shape)
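The substantive change above is that TransformerDecoderLayer no longer takes hidden_size or initializer_range in __init__; the hidden size is read from the target tensor shape in build(), where the divisibility check against num_attention_heads now lives. A standalone sketch of that inference logic, written against plain TensorFlow rather than the layer itself:

import tensorflow as tf


def infer_attention_head_size(input_shapes, num_attention_heads):
  """Mirrors the build-time shape check added to TransformerDecoderLayer."""
  target_tensor_shape = tf.TensorShape(input_shapes[0])
  if len(target_tensor_shape) != 3:
    raise ValueError("Expected a three-dimensional input of "
                     "shape [batch, sequence, width].")
  hidden_size = target_tensor_shape[2]
  if hidden_size % num_attention_heads != 0:
    raise ValueError(
        "The hidden size (%d) is not a multiple of the number of attention "
        "heads (%d)" % (hidden_size, num_attention_heads))
  return int(hidden_size / num_attention_heads)


# With the shapes used in transformer_test.py below: width 16, 2 heads -> head size 8.
print(infer_attention_head_size([[2, 4, 16]], num_attention_heads=2))  # prints 8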
official/nlp/modeling/layers/transformer_test.py

@@ -233,13 +233,11 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
     num_attention_heads = 2
     hidden_size = 16
     decoder_block = transformer.TransformerDecoderLayer(
-        hidden_size=hidden_size,
         num_attention_heads=num_attention_heads,
         intermediate_size=32,
         intermediate_activation='relu',
-        hidden_dropout_prob=0.1,
-        attention_probs_dropout_prob=0.1,
-        initializer_range=0.1)
+        dropout_rate=0.1,
+        attention_dropout_rate=0.1)
     # Forward path.
     dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
     dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
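For code written against the old constructor, the migration visible in this test is a renaming plus two removals: hidden_dropout_prob becomes dropout_rate, attention_probs_dropout_prob becomes attention_dropout_rate, and hidden_size / initializer_range disappear. A hypothetical helper that performs this mapping on a keyword dictionary (not part of the commit):

def migrate_decoder_layer_kwargs(old_kwargs):
  """Hypothetical helper: maps pre-commit TransformerDecoderLayer kwargs to the
  new names. hidden_size is dropped because the layer now infers it from the
  input shape; initializer_range is dropped here and should instead be expressed
  as an explicit kernel_initializer (see the decoder.py change below)."""
  renames = {
      "hidden_dropout_prob": "dropout_rate",
      "attention_probs_dropout_prob": "attention_dropout_rate",
  }
  dropped = {"hidden_size", "initializer_range"}
  return {renames.get(k, k): v
          for k, v in old_kwargs.items() if k not in dropped}


new_kwargs = migrate_decoder_layer_kwargs({
    "hidden_size": 16,
    "num_attention_heads": 2,
    "intermediate_size": 32,
    "intermediate_activation": "relu",
    "hidden_dropout_prob": 0.1,
    "attention_probs_dropout_prob": 0.1,
    "initializer_range": 0.1,
})
# -> num_attention_heads=2, intermediate_size=32, intermediate_activation='relu',
#    dropout_rate=0.1, attention_dropout_rate=0.1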
official/nlp/nhnet/decoder.py

@@ -60,13 +60,13 @@ class TransformerDecoder(tf.keras.layers.Layer):
     for i in range(self.num_hidden_layers):
       self.layers.append(
           transformer.TransformerDecoderLayer(
-              hidden_size=self.hidden_size,
               num_attention_heads=self.num_attention_heads,
               intermediate_size=self.intermediate_size,
               intermediate_activation=self.intermediate_activation,
-              hidden_dropout_prob=self.hidden_dropout_prob,
-              attention_probs_dropout_prob=self.attention_probs_dropout_prob,
-              initializer_range=self.initializer_range,
+              dropout_rate=self.hidden_dropout_prob,
+              attention_dropout_rate=self.attention_probs_dropout_prob,
+              kernel_initializer=tf.keras.initializers.TruncatedNormal(
+                  stddev=self.initializer_range),
               multi_channel_cross_attention=self.multi_channel_cross_attention,
               name=("layer_%d" % i)))
     super(TransformerDecoder, self).build(unused_input_shapes)
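The nhnet decoder shows how a caller keeps the old initializer_range behaviour under the new interface: instead of passing the range, it wraps it in an explicit truncated-normal initializer. A short sketch of just that translation (the 0.02 value is illustrative and not from the diff):

import tensorflow as tf

initializer_range = 0.02  # illustrative; nhnet reads this from its own config

# Old call site:  initializer_range=initializer_range
# New call site:
kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev=initializer_range)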