ModelZoo / ResNet50_tensorflow · Commit 5a4c4e18

Authored Jun 11, 2020 by Chen Chen; committed by A. Unique TensorFlower on Jun 11, 2020.

Internal Change

PiperOrigin-RevId: 316054828

Parent: 9cdb5d72
Showing 2 changed files with 21 additions and 2 deletions (+21, -2).
official/nlp/modeling/layers/gated_feedforward.py     +11 -1
official/nlp/modeling/layers/transformer_scaffold.py  +10 -1
official/nlp/modeling/layers/gated_feedforward.py
@@ -105,19 +105,27 @@ class GatedFeedforward(tf.keras.layers.Layer):
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint)
     self._intermediate_dense = []
+    self._intermediate_activation_layers = []
     self._gate_dense = []
     self._output_dense = []
     self._output_dropout = []
     self._output_layer_norm = []
+    activation_policy = tf.keras.mixed_precision.experimental.global_policy()
+    if activation_policy.name == "mixed_bfloat16":
+      # bfloat16 causes BERT with the LAMB optimizer to not converge
+      # as well, so we use float32.
+      # TODO(b/154538392): Investigate this.
+      activation_policy = tf.float32
     for i in range(self._num_blocks):
       self._intermediate_dense.append(
           tf.keras.layers.experimental.EinsumDense(
               "abc,cd->abd",
               output_shape=(None, self._intermediate_size),
               bias_axes="d",
-              activation=self._intermediate_activation,
               name="intermediate_%d" % i,
               **common_kwargs))
+      self._intermediate_activation_layers.append(tf.keras.layers.Activation(
+          self._intermediate_activation, dtype=activation_policy))
       if self._use_gate:
         self._gate_dense.append(
             tf.keras.layers.experimental.EinsumDense(
@@ -180,6 +188,8 @@ class GatedFeedforward(tf.keras.layers.Layer):
     for i in range(self._num_blocks):
       layer_input = layer_output
       intermediate_output = self._intermediate_dense[i](layer_input)
+      intermediate_output = self._intermediate_activation_layers[i](
+          intermediate_output)
       if self._use_gate:
         gated_linear = self._gate_dense[i](layer_input)
         intermediate_output = intermediate_output * gated_linear
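
Note on the change above: the intermediate activation is moved out of the EinsumDense layer into a standalone tf.keras.layers.Activation whose dtype is forced to float32 whenever the global Keras policy is mixed_bfloat16, so only the nonlinearity leaves the mixed policy. Below is a minimal, illustrative sketch of that dtype-override pattern, written against the TF 2.x experimental mixed-precision API used in this diff; it is not code from the commit, and the toy shapes and layer names are assumptions.

import tensorflow as tf

# Illustrative only: reproduce the workaround's effect in isolation.
tf.keras.mixed_precision.experimental.set_policy("mixed_bfloat16")

activation_policy = tf.keras.mixed_precision.experimental.global_policy()
if activation_policy.name == "mixed_bfloat16":
  # Same fallback as the commit: keep the nonlinearity in float32 because
  # bfloat16 hurts BERT + LAMB convergence (b/154538392).
  activation_policy = tf.float32

# The dense layer keeps the mixed policy; the activation layer does not.
dense = tf.keras.layers.experimental.EinsumDense(
    "abc,cd->abd", output_shape=(None, 32), bias_axes="d")
act = tf.keras.layers.Activation("relu", dtype=activation_policy)

x = tf.random.uniform((2, 4, 16))
y = dense(x)             # computed and returned in bfloat16
z = act(y)               # cast to float32 inside the Activation layer
print(y.dtype, z.dtype)  # expect bfloat16, then float32

tf.keras.mixed_precision.experimental.set_policy("float32")  # reset global state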
official/nlp/modeling/layers/transformer_scaffold.py
@@ -198,9 +198,16 @@ class TransformerScaffold(tf.keras.layers.Layer):
           "abc,cd->abd",
           output_shape=(None, self._intermediate_size),
           bias_axes="d",
-          activation=self._intermediate_activation,
           name="intermediate",
           **common_kwargs)
+      policy = tf.keras.mixed_precision.experimental.global_policy()
+      if policy.name == "mixed_bfloat16":
+        # bfloat16 causes BERT with the LAMB optimizer to not converge
+        # as well, so we use float32.
+        # TODO(b/154538392): Investigate this.
+        policy = tf.float32
+      self._intermediate_activation_layer = tf.keras.layers.Activation(
+          self._intermediate_activation, dtype=policy)
     self._output_dense = tf.keras.layers.experimental.EinsumDense(
         "abc,cd->abd",
         output_shape=(None, hidden_size),
@@ -263,6 +270,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
...
@@ -263,6 +270,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
attention_output
)
attention_output
)
if
self
.
_feedforward_block
is
None
:
if
self
.
_feedforward_block
is
None
:
intermediate_output
=
self
.
_intermediate_dense
(
attention_output
)
intermediate_output
=
self
.
_intermediate_dense
(
attention_output
)
intermediate_output
=
self
.
_intermediate_activation_layer
(
intermediate_output
)
layer_output
=
self
.
_output_dense
(
intermediate_output
)
layer_output
=
self
.
_output_dense
(
intermediate_output
)
layer_output
=
self
.
_output_dropout
(
layer_output
)
layer_output
=
self
.
_output_dropout
(
layer_output
)
# During mixed precision training, attention_output is from layer norm
# During mixed precision training, attention_output is from layer norm
...
...
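
The transformer_scaffold.py change applies the same workaround to the scaffold's built-in feedforward path: the activation argument is dropped from the intermediate EinsumDense and applied by a separate Activation layer that may run in float32, after which the output dense returns to the mixed policy. A rough, self-contained sketch of that call path under a mixed_bfloat16 policy follows (illustrative names and shapes, not the real TransformerScaffold).

import tensorflow as tf

hidden_size, intermediate_size = 16, 64
tf.keras.mixed_precision.experimental.set_policy("mixed_bfloat16")

policy = tf.keras.mixed_precision.experimental.global_policy()
if policy.name == "mixed_bfloat16":
  policy = tf.float32  # same bfloat16/LAMB workaround as the commit

intermediate_dense = tf.keras.layers.experimental.EinsumDense(
    "abc,cd->abd", output_shape=(None, intermediate_size), bias_axes="d",
    name="intermediate")
intermediate_activation_layer = tf.keras.layers.Activation("relu", dtype=policy)
output_dense = tf.keras.layers.experimental.EinsumDense(
    "abc,cd->abd", output_shape=(None, hidden_size), bias_axes="d", name="output")

attention_output = tf.random.uniform((2, 8, hidden_size))
intermediate_output = intermediate_dense(attention_output)  # bfloat16 compute
intermediate_output = intermediate_activation_layer(
    intermediate_output)                                     # activation in float32
layer_output = output_dense(intermediate_output)             # cast back to bfloat16
print(layer_output.dtype)

tf.keras.mixed_precision.experimental.set_policy("float32")  # reset global state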