ModelZoo / ResNet50_tensorflow · Commits

Commit a39f18f9
Authored Dec 03, 2021 by Yuexin Wu
Committed by A. Unique TensorFlower, Dec 03, 2021
Allow Funnel Transformer to switch between basic TransformerBlocks (added ReZero).
PiperOrigin-RevId: 414027254
Parent: e293e338
Showing 6 changed files with 169 additions and 54 deletions (+169, -54)
official/nlp/modeling/layers/rezero_transformer.py          +25  -1
official/nlp/modeling/layers/rezero_transformer_test.py      +6  -2
official/nlp/modeling/layers/transformer_encoder_block.py    +4  -1
official/nlp/modeling/layers/util.py                        +28  -2
official/nlp/modeling/networks/funnel_transformer.py        +88 -39
official/nlp/modeling/networks/funnel_transformer_test.py   +18  -9
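For context, a minimal usage sketch of what this commit enables (not part of the diff; it assumes the TF Model Garden `official` package is importable): FunnelTransformerEncoder can now be built on either TransformerEncoderBlock or ReZeroTransformer, selected by class or by its registered string name.

# Minimal sketch, not part of this commit. Assumes the Model Garden package
# `official` is on the Python path.
from official.nlp.modeling import layers
from official.nlp.modeling.networks import funnel_transformer

# Default behaviour is unchanged: plain TransformerEncoderBlock layers.
baseline = funnel_transformer.FunnelTransformerEncoder(
    vocab_size=30522, num_layers=3, pool_stride=2)

# ReZero variant, selected either by class or by the string 'ReZeroTransformer'.
rezero = funnel_transformer.FunnelTransformerEncoder(
    vocab_size=30522,
    num_layers=3,
    pool_stride=2,
    transformer_cls=layers.ReZeroTransformer,
    share_rezero=False)  # give attention and FFN branches separate alphas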
official/nlp/modeling/layers/rezero_transformer.py

@@ -18,6 +18,8 @@
 import gin
 import tensorflow as tf

+from official.nlp.modeling.layers import util
+

 @tf.keras.utils.register_keras_serializable(package="Text")
 @gin.configurable
@@ -45,6 +47,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
     kernel_constraint: Constraint for dense layer kernels.
     bias_constraint: Constraint for dense layer kernels.
     use_layer_norm: If add layer_norm on top of the ReZero.
+    share_rezero: If attention layer and FFN layer share the same alpha.
   """

   def __init__(self,
@@ -62,7 +65,14 @@ class ReZeroTransformer(tf.keras.layers.Layer):
                kernel_constraint=None,
                bias_constraint=None,
                use_layer_norm=False,
+               share_rezero=True,
                **kwargs):
+    # attention_dropout will override attention_dropout_rate.
+    # This is to unify the input params with TransformerEncoderBlock.
+    attention_dropout_rate = kwargs.pop("attention_dropout",
+                                        attention_dropout_rate)
+    dropout_rate = kwargs.pop("output_dropout", dropout_rate)
+    util.filter_kwargs(kwargs)
     super(ReZeroTransformer, self).__init__(**kwargs)

     self._num_heads = num_attention_heads
@@ -78,6 +88,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
     self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
     self._bias_constraint = tf.keras.constraints.get(bias_constraint)
     self._use_layer_norm = use_layer_norm
+    self._share_rezero = share_rezero

   def build(self, input_shape):
     if isinstance(input_shape, tf.TensorShape):
@@ -165,6 +176,15 @@ class ReZeroTransformer(tf.keras.layers.Layer):
         trainable=True,
         dtype=tf.float32)
+    if self._share_rezero:
+      self._rezero_a_ffn = self._rezero_a
+    else:
+      self._rezero_a_ffn = self.add_weight(
+          name="rezero_alpha_ffn",
+          initializer=tf.keras.initializers.Zeros(),
+          trainable=True,
+          dtype=tf.float32)
+
     super(ReZeroTransformer, self).build(input_shape)

   def get_config(self):
@@ -183,6 +203,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
             self._output_range,
         "use_layer_norm":
             self._use_layer_norm,
+        "share_rezero":
+            self._share_rezero,
         "kernel_initializer":
             tf.keras.initializers.serialize(self._kernel_initializer),
         "bias_initializer":
@@ -203,6 +225,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
   def reset_rezero(self):
     self._rezero_a.assign(0.)
+    if not self._share_rezero:
+      self._rezero_a_ffn.assign(0.)

   def call(self, inputs):
     if isinstance(inputs, (list, tuple)):
@@ -243,7 +267,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
       layer_output = self._output_dropout(layer_output)
       # During mixed precision training, attention_output is from layer norm and
       # is always fp32 for now. Cast layer_output to fp32 for the subsequent add.
-      layer_output = attention_output + tf.cast(self._rezero_a * layer_output,
+      layer_output = attention_output + tf.cast(self._rezero_a_ffn * layer_output,
                                                 tf.float32)
       if self._use_layer_norm:
         layer_output = self._output_layer_norm(layer_output)
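The new share_rezero flag documented above decides whether the attention and feed-forward residual branches reuse one learnable scalar or each get their own weight (rezero_alpha_ffn). The toy layer below only illustrates that mechanic; it is a sketch, not the library implementation, and ToyReZeroBlock with its Dense stand-ins is invented for the example.

import tensorflow as tf

class ToyReZeroBlock(tf.keras.layers.Layer):
  """Illustrative only: shared vs. separate ReZero alphas."""

  def __init__(self, units, share_rezero=True, **kwargs):
    super().__init__(**kwargs)
    self._share_rezero = share_rezero
    self._attn = tf.keras.layers.Dense(units)  # stand-in for self-attention
    self._ffn = tf.keras.layers.Dense(units)   # stand-in for the feed-forward net

  def build(self, input_shape):
    self._alpha = self.add_weight(
        name='rezero_alpha', initializer='zeros', trainable=True)
    # Mirrors the diff: reuse the same variable, or create a second one.
    self._alpha_ffn = self._alpha if self._share_rezero else self.add_weight(
        name='rezero_alpha_ffn', initializer='zeros', trainable=True)
    super().build(input_shape)

  def call(self, x):
    x = x + self._alpha * self._attn(x)        # attention branch
    return x + self._alpha_ffn * self._ffn(x)  # FFN branch

# Both alphas start at zero, so an untrained block is the identity map.
x = tf.random.normal([2, 4, 8])
out = ToyReZeroBlock(8, share_rezero=False)(x)
assert float(tf.reduce_max(tf.abs(out - x))) == 0.0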
official/nlp/modeling/layers/rezero_transformer_test.py

@@ -14,6 +14,7 @@
 """Tests for Keras-based rezero-transformer block layer."""

+from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
@@ -30,12 +31,15 @@ class TransformerWithReZeroLayerTest(keras_parameterized.TestCase):
     super(TransformerWithReZeroLayerTest, self).tearDown()
     tf.keras.mixed_precision.set_global_policy('float32')

-  def test_layer_invocation_with_float16_dtype(self):
+  @parameterized.named_parameters(('no_share_attn_ffn', False),
+                                  ('share_attn_ffn', True))
+  def test_layer_invocation_with_float16_dtype(self, share_rezero):
     tf.keras.mixed_precision.set_global_policy('mixed_float16')
     test_layer = rezero_transformer.ReZeroTransformer(
         num_attention_heads=10,
         intermediate_size=2048,
-        intermediate_activation='relu')
+        intermediate_activation='relu',
+        share_rezero=share_rezero)
     sequence_length = 21
     width = 80

     # Create a 3-dimensional input (the first dimension is implicit).
official/nlp/modeling/layers/transformer_encoder_block.py

@@ -16,6 +16,8 @@
 import tensorflow as tf

+from official.nlp.modeling.layers import util
+

 @tf.keras.utils.register_keras_serializable(package="Text")
 class TransformerEncoderBlock(tf.keras.layers.Layer):
@@ -86,8 +88,9 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         kernel.
       attention_axes: axes over which the attention is applied. `None` means
         attention over all axes, but batch, heads, and features.
-      **kwargs: keyword arguments/
+      **kwargs: keyword arguments.
     """
+    util.filter_kwargs(kwargs)
     super().__init__(**kwargs)

     self._num_heads = num_attention_heads
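A side effect worth noting, sketched below under the assumption that the `official` package is importable: because the constructor now routes **kwargs through util.filter_kwargs, ReZero-specific options such as share_rezero are dropped before they reach Keras' base Layer and raise. This is what lets the funnel encoder pass one kwarg set to either block class.

# Sketch, not part of the diff: passing a ReZero-only option to the plain block
# no longer raises, because util.filter_kwargs pops it before Layer.__init__.
from official.nlp.modeling import layers

block = layers.TransformerEncoderBlock(
    num_attention_heads=2,
    inner_dim=128,
    inner_activation='relu',
    share_rezero=False)  # unknown to this block; filtered out, not an error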
official/nlp/modeling/layers/util.py

@@ -30,13 +30,13 @@ class TfFunctionIfEagerDecorator(object):
     @functools.wraps(func)
     def wrapped_func(*args):
       # TODO(b/150147476, b/150024785): Fix tf.function in TF1 crash.
-      if not hasattr(tf.compat.v1, "executing_eagerly_outside_functions"
+      if not hasattr(tf.compat.v1, 'executing_eagerly_outside_functions'
                     ) or tf.compat.v1.executing_eagerly_outside_functions():
         return tf.function(func=func, **self.func_kwargs)(*args)
       return func(*args)

     # Cache the created function in self._call_impl.
-    if not hasattr(self, "_call_impl"):
+    if not hasattr(self, '_call_impl'):
       self._call_impl = wrapped_func
     return self._call_impl
@@ -44,3 +44,29 @@ class TfFunctionIfEagerDecorator(object):

 def tf_function_if_eager(**kwargs):
   """Applies the @tf.function decorator only if running in eager mode."""
   return TfFunctionIfEagerDecorator(**kwargs)
+
+
+def filter_kwargs(kwargs):
+  """In place removes unused options in kwargs.
+
+  This function removes the construction signatures: e.g.
+  number_attention_heads... in TransformerEncoderBlock. This is needed,
+  otherwise base_layer.py in Keras will complain.
+
+  Args:
+    kwargs: keyword arguments to be filtered.
+  """
+  # This is the union of signatures of TransformerEncoderBlock and
+  # ReZeroTransformer. Every Transformer
+  # block that uses compatible signature with TransformerEncoderBlock should
+  # call this function before base constructor super().__init__(**kwargs).
+  denylist = [
+      'num_attention_heads', 'intermediate_size', 'intermediate_activation',
+      'inner_dim', 'inner_activation', 'output_range', 'kernel_initializer',
+      'bias_initializer', 'kernel_regularizer', 'bias_regularizer',
+      'activity_regularizer', 'kernel_constraint', 'bias_constraint',
+      'use_bias', 'norm_first', 'norm_epsilon', 'output_dropout',
+      'attention_dropout', 'inner_dropout', 'attention_initializer',
+      'attention_axes', 'share_rezero'
+  ]
+  for unused_key in denylist:
+    kwargs.pop(unused_key, None)
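The intended call pattern for the new filter_kwargs helper, per the comment above, is to filter before calling the base constructor from any block whose signature mirrors TransformerEncoderBlock's. The MyTransformerBlock below is a hypothetical example for illustration, not part of the Model Garden.

# Sketch of the call pattern; MyTransformerBlock is invented for this example.
import tensorflow as tf
from official.nlp.modeling.layers import util

class MyTransformerBlock(tf.keras.layers.Layer):

  def __init__(self, num_attention_heads, inner_dim, **kwargs):
    # Drop any remaining TransformerEncoderBlock/ReZeroTransformer-style
    # options (e.g. share_rezero) that this block does not consume, so
    # tf.keras.layers.Layer.__init__ does not raise on unknown kwargs.
    util.filter_kwargs(kwargs)
    super().__init__(**kwargs)
    self._num_heads = num_attention_heads
    self._inner_dim = inner_dim

# Known-but-unused construction options are silently removed:
blk = MyTransformerBlock(num_attention_heads=4, inner_dim=256, share_rezero=True)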
official/nlp/modeling/networks/funnel_transformer.py

@@ -15,17 +15,32 @@
 """Funnel Transformer network."""
 # pylint: disable=g-classes-have-attributes

-from typing import Union, Sequence
+from typing import Any, Callable, Optional, Union, Sequence

 from absl import logging
 import numpy as np
 import tensorflow as tf

 from official.nlp.modeling import layers

+_Initializer = Union[str, tf.keras.initializers.Initializer]
+_Activation = Union[str, Callable[..., Any]]
+
 _MAX = 'max'
 _AVG = 'avg'
 _TRUNCATED_AVG = 'truncated_avg'

+_transformer_cls2str = {
+    layers.TransformerEncoderBlock: 'TransformerEncoderBlock',
+    layers.ReZeroTransformer: 'ReZeroTransformer'
+}
+
+_str2transformer_cls = {
+    'TransformerEncoderBlock': layers.TransformerEncoderBlock,
+    'ReZeroTransformer': layers.ReZeroTransformer
+}
+
+_approx_gelu = lambda x: tf.keras.activations.gelu(x, approximate=True)
+

 def _get_policy_dtype():
   try:
@@ -206,29 +221,37 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
       embeddings for the input word IDs.
     norm_first: Whether to normalize inputs to attention and intermediate dense
       layers. If set False, output of attention and intermediate dense layers is
-      normalized.
+      normalized. This does not apply to ReZero.
+    transformer_cls: str or a keras Layer. This is the base TransformerBlock the
+      funnel encoder relies on.
+    share_rezero: bool. Whether to share ReZero alpha between the attention
+      layer and the ffn layer. This option is specific to ReZero.
   """

   def __init__(
       self,
-      vocab_size,
-      hidden_size=768,
-      num_layers=12,
-      num_attention_heads=12,
-      max_sequence_length=512,
-      type_vocab_size=16,
-      inner_dim=3072,
-      inner_activation=lambda x: tf.keras.activations.gelu(x, approximate=True),
-      output_dropout=0.1,
-      attention_dropout=0.1,
-      pool_type=_MAX,
-      pool_stride=2,
-      unpool_length=0,
-      initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
-      output_range=None,
-      embedding_width=None,
-      embedding_layer=None,
-      norm_first=False,
+      vocab_size: int,
+      hidden_size: int = 768,
+      num_layers: int = 12,
+      num_attention_heads: int = 12,
+      max_sequence_length: int = 512,
+      type_vocab_size: int = 16,
+      inner_dim: int = 3072,
+      inner_activation: _Activation = _approx_gelu,
+      output_dropout: float = 0.1,
+      attention_dropout: float = 0.1,
+      pool_type: str = _MAX,
+      pool_stride: int = 2,
+      unpool_length: int = 0,
+      initializer: _Initializer = tf.keras.initializers.TruncatedNormal(
+          stddev=0.02),
+      output_range: Optional[int] = None,
+      embedding_width: Optional[int] = None,
+      embedding_layer: Optional[tf.keras.layers.Layer] = None,
+      norm_first: bool = False,
+      transformer_cls: Union[
+          str, tf.keras.layers.Layer] = layers.TransformerEncoderBlock,
+      share_rezero: bool = True,
       **kwargs):
     super().__init__(**kwargs)
     activation = tf.keras.activations.get(inner_activation)
@@ -278,16 +301,22 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
     self._transformer_layers = []
     self._attention_mask_layer = layers.SelfAttentionMask(
         name='self_attention_mask')
+    # Will raise an error if the string is not supported.
+    if isinstance(transformer_cls, str):
+      transformer_cls = _str2transformer_cls[transformer_cls]
     for i in range(num_layers):
-      layer = layers.TransformerEncoderBlock(
+      layer = transformer_cls(
           num_attention_heads=num_attention_heads,
+          intermediate_size=inner_dim,
           inner_dim=inner_dim,
+          intermediate_activation=inner_activation,
           inner_activation=inner_activation,
           output_dropout=output_dropout,
           attention_dropout=attention_dropout,
           norm_first=norm_first,
           output_range=output_range if i == num_layers - 1 else None,
           kernel_initializer=initializer,
+          share_rezero=share_rezero,
           name='transformer/layer_%d' % i)
       self._transformer_layers.append(layer)
@@ -333,24 +362,44 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
     self._pool_type = pool_type

     self._config = {
-        'vocab_size': vocab_size,
-        'hidden_size': hidden_size,
-        'num_layers': num_layers,
-        'num_attention_heads': num_attention_heads,
-        'max_sequence_length': max_sequence_length,
-        'type_vocab_size': type_vocab_size,
-        'inner_dim': inner_dim,
-        'inner_activation': tf.keras.activations.serialize(activation),
-        'output_dropout': output_dropout,
-        'attention_dropout': attention_dropout,
-        'initializer': tf.keras.initializers.serialize(initializer),
-        'output_range': output_range,
-        'embedding_width': embedding_width,
-        'embedding_layer': embedding_layer,
-        'norm_first': norm_first,
-        'pool_type': pool_type,
-        'pool_stride': pool_stride,
-        'unpool_length': unpool_length,
+        'vocab_size':
+            vocab_size,
+        'hidden_size':
+            hidden_size,
+        'num_layers':
+            num_layers,
+        'num_attention_heads':
+            num_attention_heads,
+        'max_sequence_length':
+            max_sequence_length,
+        'type_vocab_size':
+            type_vocab_size,
+        'inner_dim':
+            inner_dim,
+        'inner_activation':
+            tf.keras.activations.serialize(activation),
+        'output_dropout':
+            output_dropout,
+        'attention_dropout':
+            attention_dropout,
+        'initializer':
+            tf.keras.initializers.serialize(initializer),
+        'output_range':
+            output_range,
+        'embedding_width':
+            embedding_width,
+        'embedding_layer':
+            embedding_layer,
+        'norm_first':
+            norm_first,
+        'pool_type':
+            pool_type,
+        'pool_stride':
+            pool_stride,
+        'unpool_length':
+            unpool_length,
+        'transformer_cls':
+            _transformer_cls2str.get(transformer_cls, str(transformer_cls))
     }

   def call(self, inputs):
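Because the config dict above now records transformer_cls by its registered string name via _transformer_cls2str, and the constructor resolves strings back through _str2transformer_cls, a ReZero-based funnel encoder can be described and rebuilt from its config. The sketch below assumes the `official` package is importable and that the encoder's get_config() returns the _config dict built above.

# Sketch, assuming get_config() exposes the _config dict shown in the diff.
from official.nlp.modeling import layers
from official.nlp.modeling.networks import funnel_transformer

encoder = funnel_transformer.FunnelTransformerEncoder(
    vocab_size=100,
    num_layers=2,
    transformer_cls=layers.ReZeroTransformer)

config = encoder.get_config()
# _transformer_cls2str maps the class to its registered name...
assert config['transformer_cls'] == 'ReZeroTransformer'
# ...and the constructor maps the string back through _str2transformer_cls,
# so FunnelTransformerEncoder(**config) would rebuild a ReZero-based encoder.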
official/nlp/modeling/networks/funnel_transformer_test.py

@@ -38,13 +38,20 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
     tf.keras.mixed_precision.set_global_policy("float32")

   @parameterized.named_parameters(
-      ("mix_truncated_avg", "mixed_float16", tf.float16, "truncated_avg"),
-      ("float32_truncated_avg", "float32", tf.float32, "truncated_avg"),
-      ("mix_max", "mixed_float16", tf.float16, "max"),
-      ("float32_max", "float32", tf.float32, "max"),
-      ("mix_avg", "mixed_float16", tf.float16, "avg"),
-      ("float32_avg", "float32", tf.float32, "avg"))
-  def test_network_creation(self, policy, pooled_dtype, pool_type):
+      ("mix_truncated_avg_rezero", "mixed_float16", tf.float16, "truncated_avg",
+       "ReZeroTransformer"),
+      ("float32_truncated_avg_rezero", "float32",
+       tf.float32, "truncated_avg", "ReZeroTransformer"),
+      ("mix_truncated_avg", "mixed_float16", tf.float16, "truncated_avg",
+       "TransformerEncoderBlock"),
+      ("float32_truncated_avg", "float32", tf.float32, "truncated_avg",
+       "TransformerEncoderBlock"),
+      ("mix_max", "mixed_float16", tf.float16, "max", "TransformerEncoderBlock"),
+      ("float32_max", "float32", tf.float32, "max", "TransformerEncoderBlock"),
+      ("mix_avg", "mixed_float16", tf.float16, "avg", "TransformerEncoderBlock"),
+      ("float32_avg", "float32", tf.float32, "avg", "TransformerEncoderBlock"))
+  def test_network_creation(self, policy, pooled_dtype, pool_type,
+                            transformer_cls):
     tf.keras.mixed_precision.set_global_policy(policy)
     hidden_size = 32
@@ -60,7 +67,8 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
         pool_stride=pool_stride,
         pool_type=pool_type,
         max_sequence_length=sequence_length,
-        unpool_length=0)
+        unpool_length=0,
+        transformer_cls=transformer_cls)
     # Create the inputs (note that the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
@@ -253,7 +261,8 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
         norm_first=False,
         pool_type="max",
         pool_stride=2,
-        unpool_length=0)
+        unpool_length=0,
+        transformer_cls="TransformerEncoderBlock")
     network = funnel_transformer.FunnelTransformerEncoder(**kwargs)
     expected_config = dict(kwargs)
     expected_config["inner_activation"] = tf.keras.activations.serialize(