Commit 8eb91073 authored Jun 02, 2020 by Chen Chen, committed by A. Unique TensorFlower on Jun 02, 2020

Internal change

PiperOrigin-RevId: 314373769
parent c25c3e88

Showing 5 changed files with 359 additions and 7 deletions (+359, -7)
official/nlp/modeling/layers/README.md                  +4   -0
official/nlp/modeling/layers/__init__.py                +1   -0
official/nlp/modeling/layers/gated_feedforward.py       +200 -0
official/nlp/modeling/layers/gated_feedforward_test.py  +127 -0
official/nlp/modeling/layers/transformer_scaffold.py    +27  -7
official/nlp/modeling/layers/README.md

...
@@ -47,3 +47,7 @@ assemble new layers, networks, or models.
 *   [ClassificationHead](cls_head.py) A pooling head over a sequence of
     embeddings, commonly used by classification tasks.
+*   [GatedFeedforward](gated_feedforward.py) implements the gated linear layer
+    feedforward as described in
+    ["GLU Variants Improve Transformer"](https://arxiv.org/abs/2002.05202).
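For readers skimming the diff, the computation the new layer performs when gating is enabled reduces to the `GEGLU` form from the paper, `(GELU(xW) * xV) W_2`. Below is a minimal sketch in plain TensorFlow ops; the function and weight names are illustrative only and do not appear in this commit, GELU is assumed as the activation with biases omitted, and `tf.nn.gelu` requires a recent TensorFlow release.

import tensorflow as tf

def geglu_ffn(x, w, v, w2):
  # x: [batch, seq, hidden]; w, v: [hidden, intermediate]; w2: [intermediate, hidden]
  activated = tf.nn.gelu(tf.matmul(x, w))   # GELU(xW)
  gate = tf.matmul(x, v)                    # xV, the linear gate branch
  return tf.matmul(activated * gate, w2)    # (GELU(xW) * xV) W_2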
official/nlp/modeling/layers/__init__.py

...
@@ -17,6 +17,7 @@
 from official.nlp.modeling.layers.attention import *
 from official.nlp.modeling.layers.cls_head import *
 from official.nlp.modeling.layers.dense_einsum import DenseEinsum
+from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward
 from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
 from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
 from official.nlp.modeling.layers.position_embedding import PositionEmbedding
...
official/nlp/modeling/layers/gated_feedforward.py  0 → 100644
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based gated feedforward layer."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import gin
import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package="Text")
@gin.configurable
class GatedFeedforward(tf.keras.layers.Layer):
  """Gated linear feedforward layer.

  This layer follows the paper "GLU Variants Improve Transformer"
  (https://arxiv.org/abs/2002.05202). In addition, it allows stacking
  multiple feedforward blocks and specifying the position of the dropout layer.

  Arguments:
    intermediate_size: Size of the intermediate layer.
    intermediate_activation: Activation for the intermediate layer.
    dropout: Dropout probability for the output dropout.
    use_gate: Whether to use gated linear units. If True, assuming `GELU` as
      the activation and omitting bias, will apply
      `GEGLU(x, W, V, W_2) = (GELU(xW) * xV)W_2`; if False, will follow
      the "Attention Is All You Need" (https://arxiv.org/abs/1706.03762) paper
      and apply `FFN(x, W_1, W_2) = GELU(xW_1)W_2`.
    num_blocks: The number of feedforward blocks to stack. Each block contains
      a (gated) linear layer and a fully connected layer followed by dropout,
      layer norm and residual.
    dropout_position: Where to apply the dropout, the value can be either
      `before_residual` or `after_residual`. If `before_residual`, will apply
      `layer_output = layer_norm(dropout(layer_output) + layer_input)`;
      if `after_residual`, will apply
      `layer_output = dropout(layer_norm(layer_output + layer_input))`.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
    bias_regularizer: Regularizer for dense layer biases.
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
  """

  def __init__(self,
               intermediate_size,
               intermediate_activation,
               dropout,
               use_gate=True,
               num_blocks=1,
               dropout_position="before_residual",
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    super(GatedFeedforward, self).__init__(**kwargs)
    self._intermediate_size = intermediate_size
    self._intermediate_activation = intermediate_activation
    self._dropout = dropout
    self._use_gate = use_gate
    self._num_blocks = num_blocks
    self._dropout_position = dropout_position
    if self._dropout_position not in ("before_residual", "after_residual"):
      raise ValueError(
          "The dropout_position should be either `before_residual` or "
          "`after_residual`, got: %s" % self._dropout_position)

    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf.keras.constraints.get(bias_constraint)

  def build(self, input_shape):
    hidden_size = input_shape.as_list()[-1]

    common_kwargs = dict(
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)
    self._intermediate_dense = []
    self._gate_dense = []
    self._output_dense = []
    self._output_dropout = []
    self._output_layer_norm = []
    for i in range(self._num_blocks):
      self._intermediate_dense.append(
          tf.keras.layers.experimental.EinsumDense(
              "abc,cd->abd",
              output_shape=(None, self._intermediate_size),
              bias_axes="d",
              activation=self._intermediate_activation,
              name="intermediate_%d" % i,
              **common_kwargs))
      if self._use_gate:
        self._gate_dense.append(
            tf.keras.layers.experimental.EinsumDense(
                "abc,cd->abd",
                output_shape=(None, self._intermediate_size),
                bias_axes="d",
                name="gate_%d" % i,
                **common_kwargs))
      self._output_dense.append(
          tf.keras.layers.experimental.EinsumDense(
              "abc,cd->abd",
              output_shape=(None, hidden_size),
              bias_axes="d",
              name="output_%d" % i,
              **common_kwargs))
      self._output_dropout.append(tf.keras.layers.Dropout(rate=self._dropout))
      # Use float32 in layernorm for numeric stability.
      self._output_layer_norm.append(
          tf.keras.layers.LayerNormalization(
              name="output_layer_norm_%d" % i,
              axis=-1,
              epsilon=1e-12,
              dtype=tf.float32))

  def get_config(self):
    config = {
        "intermediate_size": self._intermediate_size,
        "intermediate_activation": self._intermediate_activation,
        "dropout": self._dropout,
        "use_gate": self._use_gate,
        "num_blocks": self._num_blocks,
        "dropout_position": self._dropout_position,
        "kernel_initializer":
            tf.keras.initializers.serialize(self._kernel_initializer),
        "bias_initializer":
            tf.keras.initializers.serialize(self._bias_initializer),
        "kernel_regularizer":
            tf.keras.regularizers.serialize(self._kernel_regularizer),
        "bias_regularizer":
            tf.keras.regularizers.serialize(self._bias_regularizer),
        "activity_regularizer":
            tf.keras.regularizers.serialize(self._activity_regularizer),
        "kernel_constraint":
            tf.keras.constraints.serialize(self._kernel_constraint),
        "bias_constraint":
            tf.keras.constraints.serialize(self._bias_constraint)
    }
    base_config = super(GatedFeedforward, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs):
    layer_output = inputs
    for i in range(self._num_blocks):
      layer_input = layer_output
      intermediate_output = self._intermediate_dense[i](layer_input)
      if self._use_gate:
        gated_linear = self._gate_dense[i](layer_input)
        intermediate_output = intermediate_output * gated_linear
      layer_output = self._output_dense[i](intermediate_output)
      if self._dropout_position == "before_residual":
        layer_output = self._output_dropout[i](layer_output)

      # During mixed precision training, `layer_input` may be from layer norm.
      # If so, it is always fp32. Cast layer_output to fp32 for the subsequent
      # add.
      if layer_input.dtype == tf.float32:
        layer_output = tf.cast(layer_output, tf.float32)
      layer_output = self._output_layer_norm[i](layer_output + layer_input)
      if self._dropout_position == "after_residual":
        layer_output = self._output_dropout[i](layer_output)

    return layer_output
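A short usage sketch of the layer defined above (hyperparameters and shapes are arbitrary, chosen only for illustration; the keyword arguments follow the constructor in this file):

import tensorflow as tf
from official.nlp.modeling.layers import gated_feedforward

# One gated block, dropout applied before the residual add (the default).
layer = gated_feedforward.GatedFeedforward(
    intermediate_size=512,
    intermediate_activation="relu",
    dropout=0.1,
    use_gate=True,
    num_blocks=1,
    dropout_position="before_residual")

x = tf.random.uniform((2, 16, 128))  # [batch, sequence_length, hidden]
y = layer(x)                         # output shape matches the input: (2, 16, 128)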
official/nlp/modeling/layers/gated_feedforward_test.py  0 → 100644
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras-based gated feedforward layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.layers import gated_feedforward


# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class GatedFeedforwardTest(keras_parameterized.TestCase):

  def tearDown(self):
    super(GatedFeedforwardTest, self).tearDown()
    tf.keras.mixed_precision.experimental.set_policy("float32")

  @parameterized.parameters(
      (True, 1, "after_residual", "float32"),
      (True, 1, "after_residual", "mixed_float16"),
      (False, 4, "before_residual", "float32"),
      (False, 4, "before_residual", "mixed_float16"),
      (True, 4, "after_residual", "float32"),
      (True, 4, "after_residual", "mixed_float16"),
      (False, 1, "before_residual", "float32"),
      (False, 1, "before_residual", "mixed_float16"),
  )
  def test_layer_creation(self, use_gate, num_blocks, dropout_position, dtype):
    tf.keras.mixed_precision.experimental.set_policy(dtype)
    kwargs = dict(
        intermediate_size=128,
        intermediate_activation="relu",
        dropout=0.1,
        use_gate=use_gate,
        num_blocks=num_blocks,
        dropout_position=dropout_position,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros")
    test_layer = gated_feedforward.GatedFeedforward(**kwargs)

    sequence_length = 64
    width = 128
    # Create a 3-dimensional input (the first dimension is implicit).
    data_tensor = tf.keras.Input(shape=(sequence_length, width))
    output_tensor = test_layer(data_tensor)
    # The default output of a transformer layer should be the same as the input.
    self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())

  @parameterized.parameters(
      (True, 1, "after_residual", "float32"),
      (True, 1, "after_residual", "mixed_float16"),
      (False, 4, "before_residual", "float32"),
      (False, 4, "before_residual", "mixed_float16"),
      (True, 4, "after_residual", "float32"),
      (True, 4, "after_residual", "mixed_float16"),
      (False, 1, "before_residual", "float32"),
      (False, 1, "before_residual", "mixed_float16"),
  )
  def test_layer_invocation(self, use_gate, num_blocks, dropout_position,
                            dtype):
    tf.keras.mixed_precision.experimental.set_policy(dtype)
    kwargs = dict(
        intermediate_size=16,
        intermediate_activation="relu",
        dropout=0.1,
        use_gate=use_gate,
        num_blocks=num_blocks,
        dropout_position=dropout_position,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros")
    test_layer = gated_feedforward.GatedFeedforward(**kwargs)

    sequence_length = 16
    width = 32
    # Create a 3-dimensional input (the first dimension is implicit).
    data_tensor = tf.keras.Input(shape=(sequence_length, width))
    output_tensor = test_layer(data_tensor)

    # Create a model from the test layer.
    model = tf.keras.Model(data_tensor, output_tensor)

    # Invoke the model on test data.
    batch_size = 6
    input_data = 10 * np.random.random_sample(
        (batch_size, sequence_length, width))
    output_data = model.predict(input_data)
    self.assertEqual(output_data.shape, (batch_size, sequence_length, width))

  def test_serialize_deserialize(self):
    kwargs = dict(
        intermediate_size=16,
        intermediate_activation="relu",
        dropout=0.1,
        use_gate=False,
        num_blocks=4,
        dropout_position="after_residual",
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros")
    test_layer = gated_feedforward.GatedFeedforward(**kwargs)
    new_layer = gated_feedforward.GatedFeedforward.from_config(
        test_layer.get_config())

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(test_layer.get_config(), new_layer.get_config())


if __name__ == "__main__":
  tf.test.main()
official/nlp/modeling/layers/transformer_scaffold.py

...
@@ -44,12 +44,30 @@ class TransformerScaffold(tf.keras.layers.Layer):
     intermediate_activation: Activation for the intermediate layer.
     attention_cls: A class to instantiate attention layer, or a layer instance.
     attention_cfg: The config with which to instantiate `attention_cls`. Ignored
-      if attention_cls is a layer instance.
+      if attention_cls is a layer instance or None. If `attention_cls` is a
+      class, but `attention_cfg` is None, following kwargs will be used to
+      instantiate the attention instance:
+      {
+        "num_heads": num_attention_heads,
+        "key_size": int(hidden_size // num_attention_heads),
+        "dropout": attention_dropout_rate,
+        "name": "self_attention"
+      }, where `hidden_size` is the input tensor's last dimension.
     feedforward_cls: A class to instantiate feedforward layer, or a layer
       instance. If None, will use the standard feedforward layer as described
-      in "Attention Is All You Need" paper.
+      in "Attention Is All You Need" paper. If not None, the instantiated
+      feedforward layer is expected to take the output of attention as input
+      and its output is this transformer layer's output.
     feedforward_cfg: The config with which to instantiate `feedforward_cls`.
       Ignored if feedforward_cls is a layer instance or is None.
+      If `feedforward_cls` is a class, but `feedforward_cfg` is None, following
+      kwargs will be used to instantiate the feedforward instance:
+      {
+        "intermediate_size": intermediate_size,
+        "intermediate_activation": intermediate_activation,
+        "dropout": dropout_rate,
+        "name": "feedforward"
+      }.
     dropout_rate: Dropout probability for the post-attention and output dropout.
     attention_dropout_rate: Dropout probability for within the attention layer.
     kernel_initializer: Initializer for dense layer kernels.
...
@@ -156,6 +174,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
       default_feedforward_cfg = {
           "intermediate_size": self._intermediate_size,
           "intermediate_activation": self._intermediate_activation,
+          "dropout": self._dropout_rate,
           "name": "feedforward",
       }
       default_feedforward_cfg.update(common_kwargs)
...
@@ -245,12 +264,13 @@ class TransformerScaffold(tf.keras.layers.Layer):
     if self._feedforward_block is None:
       intermediate_output = self._intermediate_dense(attention_output)
       layer_output = self._output_dense(intermediate_output)
+      layer_output = self._output_dropout(layer_output)
+      # During mixed precision training, attention_output is from layer norm
+      # and is always fp32 for now. Cast layer_output to fp32 for the subsequent
+      # add.
+      layer_output = tf.cast(layer_output, tf.float32)
+      layer_output = self._output_layer_norm(layer_output + attention_output)
     else:
       layer_output = self._feedforward_block(attention_output)
-    layer_output = self._output_dropout(layer_output)
-    # During mixed precision training, attention_output is from layer norm and
-    # is always fp32 for now. Cast layer_output to fp32 for the subsequent add.
-    layer_output = tf.cast(layer_output, tf.float32)
-    layer_output = self._output_layer_norm(layer_output + attention_output)

     return layer_output
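Tying the pieces together, here is a hedged sketch of how the new layer can be plugged into `TransformerScaffold` through `feedforward_cls`/`feedforward_cfg`, as described in the docstring updated above. The surrounding constructor arguments are taken from the docstring excerpts in this diff; check transformer_scaffold.py for the full signature. When `feedforward_cfg` is None, the scaffold instead falls back to the default config shown in the second hunk.

from official.nlp.modeling.layers import gated_feedforward
from official.nlp.modeling.layers import transformer_scaffold

# Transformer block whose feedforward sub-layer is the new GatedFeedforward.
# The explicit feedforward_cfg enables gating and stacks two blocks.
layer = transformer_scaffold.TransformerScaffold(
    num_attention_heads=8,
    intermediate_size=2048,
    intermediate_activation="relu",
    feedforward_cls=gated_feedforward.GatedFeedforward,
    feedforward_cfg={
        "intermediate_size": 2048,
        "intermediate_activation": "relu",
        "dropout": 0.1,
        "use_gate": True,
        "num_blocks": 2,
        "name": "gated_feedforward",
    },
    dropout_rate=0.1,
    attention_dropout_rate=0.1)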