ModelZoo / ResNet50_tensorflow

Commit dc03c043, authored Aug 03, 2020 by xinliupitt
intermediate dropout
Parent: e93afea8

Showing 2 changed files with 40 additions and 6 deletions:

  official/nlp/modeling/layers/transformer.py       +32  -2
  official/nlp/modeling/layers/transformer_test.py   +8  -4
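
This commit threads a new intermediate_dropout argument through both the encoder Transformer layer and TransformerDecoderLayer: when the value is larger than 0.0, build() creates a dropout layer that is applied to the output of the intermediate activation in the feed-forward sub-block. A minimal usage sketch follows; every constructor value other than norm_epsilon and intermediate_dropout is an illustrative assumption, not part of this diff.

    # Hedged sketch: construct the encoder block with the new argument and run
    # a forward pass, loosely following the updated unit tests further below.
    import tensorflow as tf
    from official.nlp.modeling.layers import transformer

    encoder_block = transformer.Transformer(
        num_attention_heads=2,           # assumed value
        intermediate_size=32,            # assumed value
        intermediate_activation="relu",  # assumed value
        dropout_rate=0.1,                # assumed value
        attention_dropout_rate=0.1,      # assumed value
        norm_epsilon=1e-6,
        intermediate_dropout=0.1)        # new in this commit

    dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)  # [batch, seq, width]
    dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)     # [batch, seq, seq]
    output = encoder_block([dummy_tensor, dummy_mask])
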
official/nlp/modeling/layers/transformer.py
@@ -55,6 +55,10 @@ class Transformer(tf.keras.layers.Layer):
       layers. If set False, output of attention and intermediate dense layers is
       normalized.
     norm_epsilon: Epsilon value to initialize normalization layers.
+    intermediate_dropout: Dropout probability for intermediate_dropout_layer. If
+      larger than 0.0, intermediate_dropout_layer is created and used after
+      intermediate_activation_layer. Otherwise, intermediate_dropout_layer is
+      None.
   """
 
   def __init__(self,
@@ -74,6 +78,7 @@ class Transformer(tf.keras.layers.Layer):
                use_bias=True,
                norm_first=False,
                norm_epsilon=1e-12,
+               intermediate_dropout=0.0,
                **kwargs):
     super(Transformer, self).__init__(**kwargs)
@@ -93,6 +98,7 @@ class Transformer(tf.keras.layers.Layer):
     self._use_bias = use_bias
     self._norm_first = norm_first
     self._norm_epsilon = norm_epsilon
+    self._intermediate_dropout = intermediate_dropout
 
   def build(self, input_shape):
     input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
@@ -155,6 +161,11 @@ class Transformer(tf.keras.layers.Layer):
       policy = tf.float32
     self._intermediate_activation_layer = tf.keras.layers.Activation(
         self._intermediate_activation, dtype=policy)
+    if self._intermediate_dropout > 0.0:
+      self.intermediate_dropout_layer = tf.keras.layers.Dropout(
+          rate=self._intermediate_dropout)
+    else:
+      self.intermediate_dropout_layer = None
     self._output_dense = tf.keras.layers.experimental.EinsumDense(
         "abc,cd->abd",
         output_shape=(None, hidden_size),
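
The dropout layer is only instantiated when the rate is positive; otherwise intermediate_dropout_layer stays None and the call path skips it. As with any tf.keras.layers.Dropout, it only perturbs activations in training mode. A small standalone check of that behavior (independent of this layer, shown only for illustration):

    import tensorflow as tf

    rate = 0.1
    # Same pattern as the build() change above, in isolation.
    dropout = tf.keras.layers.Dropout(rate=rate) if rate > 0.0 else None

    x = tf.ones([2, 4, 16])
    if dropout is not None:
      # Inference is a no-op; training zeroes random units and rescales the rest.
      print(bool(tf.reduce_all(dropout(x, training=False) == x)))   # True
      print(bool(tf.reduce_any(dropout(x, training=True) == 0.0)))  # True with very high probability
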
@@ -204,7 +215,9 @@ class Transformer(tf.keras.layers.Layer):
         "norm_first":
             self._norm_first,
         "norm_epsilon":
-            self._norm_epsilon
+            self._norm_epsilon,
+        "intermediate_dropout":
+            self._intermediate_dropout
     }
     base_config = super(Transformer, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
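
Registering intermediate_dropout in get_config is what lets the setting survive serialization. The final line merges the base layer's config with the subclass entries; the idiom in isolation (values here are placeholders):

    # Subclass entries come last in the merge, so they win on duplicate keys.
    base_config = {"name": "transformer", "dtype": "float32"}
    config = {"norm_epsilon": 1e-6, "intermediate_dropout": 0.1}
    merged = dict(list(base_config.items()) + list(config.items()))
    # -> {'name': 'transformer', 'dtype': 'float32',
    #     'norm_epsilon': 1e-6, 'intermediate_dropout': 0.1}
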
@@ -238,6 +251,8 @@ class Transformer(tf.keras.layers.Layer):
     intermediate_output = self._intermediate_dense(attention_output)
     intermediate_output = self._intermediate_activation_layer(
         intermediate_output)
+    if self.intermediate_dropout_layer:
+      intermediate_output = self.intermediate_dropout_layer(intermediate_output)
     layer_output = self._output_dense(intermediate_output)
     layer_output = self._output_dropout(layer_output)
     # During mixed precision training, attention_output is from layer norm and
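
In the forward pass the new dropout sits between the intermediate activation and the output projection, ahead of the existing output dropout. A hedged sketch of just that ordering (the helper name is made up for illustration; residual connections and layer normalization, which the real layer also applies, are omitted):

    def feed_forward(x, intermediate_dense, activation, intermediate_dropout,
                     output_dense, output_dropout):
      x = intermediate_dense(x)
      x = activation(x)
      if intermediate_dropout is not None:
        x = intermediate_dropout(x)  # new: dropout on intermediate activations
      x = output_dense(x)
      return output_dropout(x)
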
@@ -291,6 +306,10 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
       layers. If set False, output of attention and intermediate dense layers is
       normalized.
     norm_epsilon: Epsilon value to initialize normalization layers.
+    intermediate_dropout: Dropout probability for intermediate_dropout_layer. If
+      larger than 0.0, intermediate_dropout_layer is created and used after
+      intermediate_activation_layer. Otherwise, intermediate_dropout_layer is
+      None.
   """
 
   def __init__(self,
@@ -310,6 +329,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
                use_bias=True,
                norm_first=False,
                norm_epsilon=1e-12,
+               intermediate_dropout=0.0,
                **kwargs):
     super(TransformerDecoderLayer, self).__init__(**kwargs)
     self.num_attention_heads = num_attention_heads
@@ -329,6 +349,7 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
     self._use_bias = use_bias
     self._norm_first = norm_first
     self._norm_epsilon = norm_epsilon
+    self._intermediate_dropout = intermediate_dropout
     if self.multi_channel_cross_attention:
       self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
     else:
@@ -401,6 +422,11 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         **common_kwargs)
     self.intermediate_activation_layer = tf.keras.layers.Activation(
         self.intermediate_activation)
+    if self._intermediate_dropout > 0.0:
+      self.intermediate_dropout_layer = tf.keras.layers.Dropout(
+          rate=self._intermediate_dropout)
+    else:
+      self.intermediate_dropout_layer = None
     self.output_dense = tf.keras.layers.experimental.EinsumDense(
         "abc,cd->abd",
         output_shape=(None, hidden_size),
@@ -445,7 +471,9 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         "norm_first":
             self._norm_first,
         "norm_epsilon":
-            self._norm_epsilon
+            self._norm_epsilon,
+        "intermediate_dropout":
+            self._intermediate_dropout
     }
     base_config = super(TransformerDecoderLayer, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
@@ -508,6 +536,8 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
     intermediate_output = self.intermediate_dense(attention_output)
     intermediate_output = self.intermediate_activation_layer(
         intermediate_output)
+    if self.intermediate_dropout_layer:
+      intermediate_output = self.intermediate_dropout_layer(intermediate_output)
     layer_output = self.output_dense(intermediate_output)
     layer_output = self.output_dropout(layer_output)
     if self._norm_first:

official/nlp/modeling/layers/transformer_test.py
@@ -230,7 +230,8 @@ class TransformerArgumentTest(keras_parameterized.TestCase):
         attention_dropout_rate=0.1,
         use_bias=False,
         norm_first=True,
-        norm_epsilon=1e-6)
+        norm_epsilon=1e-6,
+        intermediate_dropout=0.1)
     # Forward path.
     dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
     dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
@@ -248,7 +249,8 @@ class TransformerArgumentTest(keras_parameterized.TestCase):
         attention_dropout_rate=0.1,
         use_bias=False,
         norm_first=True,
-        norm_epsilon=1e-6)
+        norm_epsilon=1e-6,
+        intermediate_dropout=0.1)
     encoder_block_config = encoder_block.get_config()
     new_encoder_block = transformer.Transformer.from_config(
         encoder_block_config)
@@ -299,7 +301,8 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
         attention_dropout_rate=0.1,
         use_bias=False,
         norm_first=True,
-        norm_epsilon=1e-6)
+        norm_epsilon=1e-6,
+        intermediate_dropout=0.1)
     # Forward path.
     dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
     dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
@@ -317,7 +320,8 @@ class TransformerDecoderLayerTest(keras_parameterized.TestCase):
         attention_dropout_rate=0.1,
         use_bias=False,
         norm_first=True,
-        norm_epsilon=1e-6)
+        norm_epsilon=1e-6,
+        intermediate_dropout=0.1)
     decoder_block_config = decoder_block.get_config()
     new_decoder_block = transformer.TransformerDecoderLayer.from_config(
         decoder_block_config)
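
The updated tests pass intermediate_dropout=0.1 to both blocks and rebuild them via get_config()/from_config(), so the new setting has to round-trip through serialization. A hedged sketch of that check (constructor values besides norm_epsilon and intermediate_dropout are assumptions, as above):

    from official.nlp.modeling.layers import transformer

    encoder_block = transformer.Transformer(
        num_attention_heads=2,           # assumed value
        intermediate_size=32,            # assumed value
        intermediate_activation="relu",  # assumed value
        norm_epsilon=1e-6,
        intermediate_dropout=0.1)

    config = encoder_block.get_config()
    assert config["intermediate_dropout"] == 0.1
    new_encoder_block = transformer.Transformer.from_config(config)
    assert new_encoder_block.get_config()["intermediate_dropout"] == 0.1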