Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
52222935
Commit
52222935
authored
Oct 21, 2021
by
Yuexin Wu
Committed by
A. Unique TensorFlower
Oct 21, 2021
Browse files
Internal change
PiperOrigin-RevId: 404920138
parent
1f885e26
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
179 additions
and
41 deletions
+179
-41
official/nlp/modeling/networks/funnel_transformer.py
official/nlp/modeling/networks/funnel_transformer.py
+168
-39
official/nlp/modeling/networks/funnel_transformer_test.py
official/nlp/modeling/networks/funnel_transformer_test.py
+11
-2
No files found.
official/nlp/modeling/networks/funnel_transformer.py
View file @
52222935
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
"""Funnel Transformer network."""
"""Funnel Transformer network."""
# pylint: disable=g-classes-have-attributes
# pylint: disable=g-classes-have-attributes
from
typing
import
Union
,
Sequence
from
typing
import
Union
,
Sequence
from
absl
import
logging
from
absl
import
logging
import
numpy
as
np
import
numpy
as
np
...
@@ -21,6 +22,10 @@ import tensorflow as tf
...
@@ -21,6 +22,10 @@ import tensorflow as tf
from
official.nlp.modeling
import
layers
from
official.nlp.modeling
import
layers
_MAX
=
'max'
_AVG
=
'avg'
_TRUNCATED_AVG
=
'truncated_avg'
def
_pool_and_concat
(
mask
,
unpool_length
:
int
,
strides
:
Union
[
Sequence
[
int
],
def
_pool_and_concat
(
mask
,
unpool_length
:
int
,
strides
:
Union
[
Sequence
[
int
],
int
],
int
],
...
@@ -63,6 +68,94 @@ def _pool_and_concat(mask, unpool_length: int, strides: Union[Sequence[int],
...
@@ -63,6 +68,94 @@ def _pool_and_concat(mask, unpool_length: int, strides: Union[Sequence[int],
return
mask
return
mask
def _create_truncated_avg_transforms(seq_length: int,
                                     pool_strides: Sequence[int]):
  """Builds one truncated-average pooling matrix per layer.

  Each transform has shape [seq_length, seq_length // pool_stride] with
  entries 1.0 / pool_stride where i // pool_stride == j and 0.0 elsewhere.
  This is average pooling that simply drops the trailing partial window
  when seq_length % pool_stride != 0.

  For seq_length == 6 and pool_stride == 2 the transform is
    [[0.5, 0.0, 0.0],
     [0.5, 0.0, 0.0],
     [0.0, 0.5, 0.0],
     [0.0, 0.5, 0.0],
     [0.0, 0.0, 0.5],
     [0.0, 0.0, 0.5]]

  Args:
    seq_length: int, input sequence length of the first layer.
    pool_strides: Sequence of pooling strides, one per layer.

  Returns:
    A list with one entry per layer: a constant Tensor transform, or None for
    layers whose stride is 1 (no pooling needed).
  """
  transforms = []
  for stride in pool_strides:
    if stride == 1:
      # Stride 1 leaves the sequence untouched; no matrix required.
      transforms.append(None)
      continue
    pooled_len = seq_length // stride
    # Row i maps to pooled position i // stride; everything else is zero.
    rows = []
    for i in range(seq_length):
      target = i // stride
      rows.append([1.0 if target == j else 0.0 for j in range(pooled_len)])
    # Cast to the active mixed-precision compute dtype so einsum inputs match.
    matrix = tf.constant(
        rows, dtype=tf.keras.mixed_precision.global_policy().compute_dtype)
    transforms.append(matrix / stride)
    seq_length = pooled_len
  return transforms
def _create_truncated_avg_masks(input_mask: tf.Tensor,
                                pool_strides: Sequence[int],
                                transforms: Sequence[tf.Tensor]):
  """Builds per-layer attention masks for truncated-average pooling.

  The padding mask is repeatedly pooled through the given transforms; a
  pooled position stays valid (1) if any input position in its window was
  valid. Each layer's 3-D attention mask broadcasts the current 1-D mask
  over the layer's query length.

  Args:
    input_mask: Tensor of shape [batch_size, seq_length] with 1 for valid
      tokens and 0 for padding.
    pool_strides: Sequence of pooling strides, one per layer.
    transforms: Sequence of pooling matrices (entries 0.0 or
      1/pool_stride) as produced by `_create_truncated_avg_transforms`;
      None entries correspond to stride-1 layers.

  Returns:
    attention_masks: list of [batch, from_length, to_length] masks, one per
    layer.
  """

  def broadcast_mask(query_length, token_mask):
    # Tile the [B, T] token mask across the query axis -> [B, F, T].
    ones = tf.ones([query_length], dtype=token_mask.dtype)
    return tf.einsum('F,BT->BFT', ones, token_mask)

  masks = []
  current_length = tf.shape(input_mask)[-1]
  # Work in the mixed-precision compute dtype throughout.
  current_mask = tf.cast(
      input_mask,
      dtype=tf.keras.mixed_precision.global_policy().compute_dtype)
  for stride, transform in zip(pool_strides, transforms):
    if stride == 1:
      masks.append(broadcast_mask(current_length, current_mask))
      continue
    pooled_length = current_length // stride
    masks.append(broadcast_mask(pooled_length, current_mask))
    # A pooled slot is valid iff its window overlaps any valid token, i.e.
    # the averaged mask value is strictly positive.
    pooled = tf.einsum('BF,FT->BT', current_mask, transform)
    current_mask = tf.cast(pooled > 0.0, dtype=current_mask.dtype)
    current_length = pooled_length
  del current_length
  return masks
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
class
FunnelTransformerEncoder
(
tf
.
keras
.
layers
.
Layer
):
class
FunnelTransformerEncoder
(
tf
.
keras
.
layers
.
Layer
):
"""Funnel Transformer-based encoder network.
"""Funnel Transformer-based encoder network.
...
@@ -90,7 +183,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
...
@@ -90,7 +183,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
dropout.
dropout.
attention_dropout: The dropout rate to use for the attention layers within
attention_dropout: The dropout rate to use for the attention layers within
the transformer layers.
the transformer layers.
pool_type: Pooling type. Choose from ['max', 'avg'].
pool_type: Pooling type. Choose from ['max', 'avg'
, 'truncated_avg'
].
pool_stride: An int or a list of ints. Pooling stride(s) to compress the
pool_stride: An int or a list of ints. Pooling stride(s) to compress the
sequence length. If set to int, each layer will have the same stride size.
sequence length. If set to int, each layer will have the same stride size.
If set to list, the number of elements needs to match num_layers.
If set to list, the number of elements needs to match num_layers.
...
@@ -124,7 +217,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
...
@@ -124,7 +217,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
inner_activation
=
lambda
x
:
tf
.
keras
.
activations
.
gelu
(
x
,
approximate
=
True
),
inner_activation
=
lambda
x
:
tf
.
keras
.
activations
.
gelu
(
x
,
approximate
=
True
),
output_dropout
=
0.1
,
output_dropout
=
0.1
,
attention_dropout
=
0.1
,
attention_dropout
=
0.1
,
pool_type
=
'max'
,
pool_type
=
_MAX
,
pool_stride
=
2
,
pool_stride
=
2
,
unpool_length
=
0
,
unpool_length
=
0
,
initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
0.02
),
initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
0.02
),
...
@@ -207,12 +300,21 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
...
@@ -207,12 +300,21 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
raise
ValueError
(
'Lengths of pool_stride and num_layers are not equal.'
)
raise
ValueError
(
'Lengths of pool_stride and num_layers are not equal.'
)
pool_strides
=
pool_stride
pool_strides
=
pool_stride
# TODO(crickwu): explore tf.keras.layers.serialize method.
# TODO(crickwu): explore tf.keras.layers.serialize method.
if
pool_type
==
'max'
:
if
pool_type
==
_MAX
:
pool_cls
=
tf
.
keras
.
layers
.
MaxPooling1D
pool_cls
=
tf
.
keras
.
layers
.
MaxPooling1D
elif
pool_type
==
'avg'
:
elif
pool_type
==
_AVG
:
pool_cls
=
tf
.
keras
.
layers
.
AveragePooling1D
pool_cls
=
tf
.
keras
.
layers
.
AveragePooling1D
elif
pool_type
==
_TRUNCATED_AVG
:
# TODO(b/203665205): unpool_length should be implemented.
if
unpool_length
!=
0
:
raise
ValueError
(
'unpool_length is not supported by truncated_avg now.'
)
# Compute the attention masks and pooling transforms.
self
.
_pooling_transforms
=
_create_truncated_avg_transforms
(
max_sequence_length
,
pool_strides
)
else
:
else
:
raise
ValueError
(
'pool_type not supported.'
)
raise
ValueError
(
'pool_type not supported.'
)
if
pool_type
in
(
_MAX
,
_AVG
):
self
.
_att_input_pool_layers
=
[]
self
.
_att_input_pool_layers
=
[]
for
layer_pool_stride
in
pool_strides
:
for
layer_pool_stride
in
pool_strides
:
att_input_pool_layer
=
pool_cls
(
att_input_pool_layer
=
pool_cls
(
...
@@ -224,6 +326,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
...
@@ -224,6 +326,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
self
.
_pool_strides
=
pool_strides
# This is a list here.
self
.
_pool_strides
=
pool_strides
# This is a list here.
self
.
_unpool_length
=
unpool_length
self
.
_unpool_length
=
unpool_length
self
.
_pool_type
=
pool_type
self
.
_config
=
{
self
.
_config
=
{
'vocab_size'
:
vocab_size
,
'vocab_size'
:
vocab_size
,
...
@@ -280,11 +383,13 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
...
@@ -280,11 +383,13 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
encoder_outputs
=
[]
encoder_outputs
=
[]
x
=
embeddings
x
=
embeddings
# TODO(b/195972228): attention_mask can be co-generated with pooling.
# TODO(b/195972228): attention_mask can be co-generated with pooling.
if
self
.
_pool_type
in
(
_MAX
,
_AVG
):
attention_mask
=
_pool_and_concat
(
attention_mask
=
_pool_and_concat
(
attention_mask
,
attention_mask
,
unpool_length
=
self
.
_unpool_length
,
unpool_length
=
self
.
_unpool_length
,
strides
=
self
.
_pool_strides
[
0
],
strides
=
self
.
_pool_strides
[
0
],
axes
=
[
1
])
axes
=
[
1
])
for
i
,
layer
in
enumerate
(
self
.
_transformer_layers
):
for
i
,
layer
in
enumerate
(
self
.
_transformer_layers
):
# Bypass no pooling cases.
# Bypass no pooling cases.
if
self
.
_pool_strides
[
i
]
==
1
:
if
self
.
_pool_strides
[
i
]
==
1
:
...
@@ -307,12 +412,36 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
...
@@ -307,12 +412,36 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
strides
=
[
self
.
_pool_strides
[
i
+
1
],
self
.
_pool_strides
[
i
]],
strides
=
[
self
.
_pool_strides
[
i
+
1
],
self
.
_pool_strides
[
i
]],
axes
=
[
1
,
2
])
axes
=
[
1
,
2
])
encoder_outputs
.
append
(
x
)
encoder_outputs
.
append
(
x
)
elif
self
.
_pool_type
==
_TRUNCATED_AVG
:
attention_masks
=
_create_truncated_avg_masks
(
mask
,
self
.
_pool_strides
,
self
.
_pooling_transforms
)
for
i
,
layer
in
enumerate
(
self
.
_transformer_layers
):
attention_mask
=
attention_masks
[
i
]
# Bypass no pooling cases.
if
self
.
_pool_strides
[
i
]
==
1
:
x
=
layer
([
x
,
x
,
attention_mask
])
else
:
pooled_inputs
=
tf
.
einsum
(
'BFD,FT->BTD'
,
tf
.
cast
(
x
[:,
self
.
_unpool_length
:,
:],
tf
.
keras
.
mixed_precision
.
global_policy
().
compute_dtype
),
# extra casting for faster mixed computation.
self
.
_pooling_transforms
[
i
])
query_inputs
=
tf
.
concat
(
values
=
(
tf
.
cast
(
x
[:,
:
self
.
_unpool_length
,
:],
dtype
=
pooled_inputs
.
dtype
),
pooled_inputs
),
axis
=
1
)
x
=
layer
([
query_inputs
,
x
,
attention_mask
])
encoder_outputs
.
append
(
x
)
last_encoder_output
=
encoder_outputs
[
-
1
]
last_encoder_output
=
encoder_outputs
[
-
1
]
first_token_tensor
=
last_encoder_output
[:,
0
,
:]
first_token_tensor
=
last_encoder_output
[:,
0
,
:]
pooled_output
=
self
.
_pooler_layer
(
first_token_tensor
)
pooled_output
=
self
.
_pooler_layer
(
first_token_tensor
)
return
dict
(
return
dict
(
word_embeddings
=
word_embeddings
,
embedding_output
=
embeddings
,
sequence_output
=
encoder_outputs
[
-
1
],
sequence_output
=
encoder_outputs
[
-
1
],
pooled_output
=
pooled_output
,
pooled_output
=
pooled_output
,
encoder_outputs
=
encoder_outputs
)
encoder_outputs
=
encoder_outputs
)
...
...
official/nlp/modeling/networks/funnel_transformer_test.py
View file @
52222935
...
@@ -38,6 +38,8 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -38,6 +38,8 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
tf
.
keras
.
mixed_precision
.
set_global_policy
(
"float32"
)
tf
.
keras
.
mixed_precision
.
set_global_policy
(
"float32"
)
@
parameterized
.
named_parameters
(
@
parameterized
.
named_parameters
(
(
"mix_truncated_avg"
,
"mixed_float16"
,
tf
.
float16
,
"truncated_avg"
),
(
"float32_truncated_avg"
,
"float32"
,
tf
.
float32
,
"truncated_avg"
),
(
"mix_max"
,
"mixed_float16"
,
tf
.
float16
,
"max"
),
(
"mix_max"
,
"mixed_float16"
,
tf
.
float16
,
"max"
),
(
"float32_max"
,
"float32"
,
tf
.
float32
,
"max"
),
(
"float32_max"
,
"float32"
,
tf
.
float32
,
"max"
),
(
"mix_avg"
,
"mixed_float16"
,
tf
.
float16
,
"avg"
),
(
"mix_avg"
,
"mixed_float16"
,
tf
.
float16
,
"avg"
),
...
@@ -57,6 +59,7 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -57,6 +59,7 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
num_layers
=
num_layers
,
num_layers
=
num_layers
,
pool_stride
=
pool_stride
,
pool_stride
=
pool_stride
,
pool_type
=
pool_type
,
pool_type
=
pool_type
,
max_sequence_length
=
sequence_length
,
unpool_length
=
0
)
unpool_length
=
0
)
# Create the inputs (note that the first dimension is implicit).
# Create the inputs (note that the first dimension is implicit).
word_ids
=
tf
.
keras
.
Input
(
shape
=
(
sequence_length
,),
dtype
=
tf
.
int32
)
word_ids
=
tf
.
keras
.
Input
(
shape
=
(
sequence_length
,),
dtype
=
tf
.
int32
)
...
@@ -71,8 +74,14 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -71,8 +74,14 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
self
.
assertIsInstance
(
test_network
.
pooler_layer
,
tf
.
keras
.
layers
.
Dense
)
self
.
assertIsInstance
(
test_network
.
pooler_layer
,
tf
.
keras
.
layers
.
Dense
)
# Stride=2 compresses sequence length to half the size at each layer.
# Stride=2 compresses sequence length to half the size at each layer.
# This configuration gives each layer of seq length: 21->11->6->3.
# For pool_type = max or avg,
# this configuration gives each layer of seq length: 21->11->6->3.
# For pool_type = truncated_avg,
# seq length: 21->10->5->2.
if
pool_type
in
[
"max"
,
"avg"
]:
expected_data_shape
=
[
None
,
3
,
hidden_size
]
expected_data_shape
=
[
None
,
3
,
hidden_size
]
else
:
expected_data_shape
=
[
None
,
2
,
hidden_size
]
expected_pooled_shape
=
[
None
,
hidden_size
]
expected_pooled_shape
=
[
None
,
hidden_size
]
self
.
assertAllEqual
(
expected_data_shape
,
data
.
shape
.
as_list
())
self
.
assertAllEqual
(
expected_data_shape
,
data
.
shape
.
as_list
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment