Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
350f4854
Commit
350f4854
authored
Sep 22, 2021
by
Yuexin Wu
Committed by
A. Unique TensorFlower
Sep 22, 2021
Browse files
Internal change
PiperOrigin-RevId: 398286699
parent
9d5a1a76
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
63 additions
and
25 deletions
+63
-25
official/nlp/modeling/networks/funnel_transformer.py
official/nlp/modeling/networks/funnel_transformer.py
+41
-20
official/nlp/modeling/networks/funnel_transformer_test.py
official/nlp/modeling/networks/funnel_transformer_test.py
+22
-5
No files found.
official/nlp/modeling/networks/funnel_transformer.py
View file @
350f4854
...
...
@@ -14,15 +14,16 @@
"""Funnel Transformer network."""
# pylint: disable=g-classes-have-attributes
from
typing
import
Union
,
Collection
from
typing
import
Union
,
Sequence
from
absl
import
logging
import
tensorflow
as
tf
from
official.nlp
import
keras_nlp
def
_pool_and_concat
(
data
,
unpool_length
:
int
,
stride
:
int
,
axes
:
Union
[
Collection
[
int
],
int
]):
def
_pool_and_concat
(
data
,
unpool_length
:
int
,
strides
:
Union
[
Sequence
[
int
],
int
],
axes
:
Union
[
Sequence
[
int
],
int
]):
"""Pools the data along a given axis with stride.
It also skips first unpool_length elements.
...
...
@@ -30,7 +31,7 @@ def _pool_and_concat(data, unpool_length: int, stride: int,
Args:
data: Tensor to be pooled.
unpool_length: Leading elements to be skipped.
stride: Stride for the given ax
i
s.
stride
s
: Stride
s
for the given ax
e
s.
axes: Axes to pool the Tensor.
Returns:
...
...
@@ -39,8 +40,13 @@ def _pool_and_concat(data, unpool_length: int, stride: int,
# Wraps the axes as a list.
if
isinstance
(
axes
,
int
):
axes
=
[
axes
]
if
isinstance
(
strides
,
int
):
strides
=
[
strides
]
*
len
(
axes
)
else
:
if
len
(
strides
)
!=
len
(
axes
):
raise
ValueError
(
'The lengths of strides and axes need to match.'
)
for
axis
in
axes
:
for
axis
,
stride
in
zip
(
axes
,
strides
)
:
# Skips first `unpool_length` tokens.
unpool_tensor_shape
=
[
slice
(
None
)]
*
axis
+
[
slice
(
None
,
unpool_length
)]
unpool_tensor
=
data
[
unpool_tensor_shape
]
...
...
@@ -80,7 +86,9 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
dropout.
attention_dropout: The dropout rate to use for the attention layers within
the transformer layers.
pool_stride: Pooling stride to compress the sequence length.
pool_stride: An int or a list of ints. Pooling stride(s) to compress the
sequence length. If set to int, each layer will have the same stride size.
If set to list, the number of elements needs to match num_layers.
unpool_length: Leading n tokens to be skipped from pooling.
initializer: The initialzer to use for all weights in this encoder.
output_range: The sequence output range, [0, output_range), by slicing the
...
...
@@ -185,12 +193,23 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
activation
=
'tanh'
,
kernel_initializer
=
initializer
,
name
=
'pooler_transform'
)
self
.
_att_input_pool_layer
=
tf
.
keras
.
layers
.
MaxPooling1D
(
pool_size
=
pool_stride
,
strides
=
pool_stride
,
if
isinstance
(
pool_stride
,
int
):
# TODO(b/197133196): Pooling layer can be shared.
pool_strides
=
[
pool_stride
]
*
num_layers
else
:
if
len
(
pool_stride
)
!=
num_layers
:
raise
ValueError
(
'Lengths of pool_stride and num_layers are not equal.'
)
pool_strides
=
pool_stride
self
.
_att_input_pool_layers
=
[]
for
layer_pool_stride
in
pool_strides
:
att_input_pool_layer
=
tf
.
keras
.
layers
.
MaxPooling1D
(
pool_size
=
layer_pool_stride
,
strides
=
layer_pool_stride
,
padding
=
'same'
,
name
=
'att_input_pool_layer'
)
self
.
_pool_stride
=
pool_stride
self
.
_att_input_pool_layers
.
append
(
att_input_pool_layer
)
self
.
_pool_strides
=
pool_strides
# This is a list here.
self
.
_unpool_length
=
unpool_length
self
.
_config
=
{
...
...
@@ -250,11 +269,12 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
attention_mask
=
_pool_and_concat
(
attention_mask
,
unpool_length
=
self
.
_unpool_length
,
stride
=
self
.
_pool_stride
,
stride
s
=
self
.
_pool_stride
s
[
0
]
,
axes
=
[
1
])
for
layer
in
self
.
_transformer_layers
:
for
i
,
layer
in
enumerate
(
self
.
_transformer_layers
)
:
# Pools layer for compressing the query length.
pooled_inputs
=
self
.
_att_input_pool_layer
(
x
[:,
self
.
_unpool_length
:,
:])
pooled_inputs
=
self
.
_att_input_pool_layers
[
i
](
x
[:,
self
.
_unpool_length
:,
:])
query_inputs
=
tf
.
concat
(
values
=
(
tf
.
cast
(
x
[:,
:
self
.
_unpool_length
,
:],
...
...
@@ -262,10 +282,11 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
axis
=
1
)
x
=
layer
([
query_inputs
,
x
,
attention_mask
])
# Pools the corresponding attention_mask.
if
i
<
len
(
self
.
_transformer_layers
)
-
1
:
attention_mask
=
_pool_and_concat
(
attention_mask
,
unpool_length
=
self
.
_unpool_length
,
stride
=
self
.
_pool_stride
,
stride
s
=
[
self
.
_pool_stride
s
[
i
+
1
],
self
.
_pool_strides
[
i
]]
,
axes
=
[
1
,
2
])
encoder_outputs
.
append
(
x
)
...
...
official/nlp/modeling/networks/funnel_transformer_test.py
View file @
350f4854
...
...
@@ -80,8 +80,24 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
self
.
assertAllEqual
(
tf
.
float32
,
data
.
dtype
)
self
.
assertAllEqual
(
pooled_dtype
,
pooled
.
dtype
)
def
test_invalid_stride_and_num_layers
(
self
):
hidden_size
=
32
num_layers
=
3
pool_stride
=
[
2
,
2
]
unpool_length
=
1
with
self
.
assertRaisesRegex
(
ValueError
,
"pool_stride and num_layers are not equal"
):
_
=
funnel_transformer
.
FunnelTransformerEncoder
(
vocab_size
=
100
,
hidden_size
=
hidden_size
,
num_attention_heads
=
2
,
num_layers
=
num_layers
,
pool_stride
=
pool_stride
,
unpool_length
=
unpool_length
)
@
parameterized
.
named_parameters
(
(
"no_stride_no_unpool"
,
1
,
0
),
(
"stride_list_with_unpool"
,
[
2
,
3
,
4
],
1
),
(
"large_stride_with_unpool"
,
3
,
1
),
(
"large_stride_with_large_unpool"
,
5
,
10
),
(
"no_stride_with_unpool"
,
1
,
1
),
...
...
@@ -110,11 +126,12 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
expected_data_shape
=
[
None
,
sequence_length
,
hidden_size
]
expected_pooled_shape
=
[
None
,
hidden_size
]
self
.
assertLen
(
all_encoder_outputs
,
num_layers
)
for
data
in
all_encoder_outputs
:
expected_data_shape
[
1
]
=
unpool_length
+
(
expected_data_shape
[
1
]
+
pool_stride
-
1
-
unpool_length
)
//
pool_stride
print
(
"shapes:"
,
expected_data_shape
,
data
.
shape
.
as_list
())
if
isinstance
(
pool_stride
,
int
):
pool_stride
=
[
pool_stride
]
*
num_layers
for
layer_pool_stride
,
data
in
zip
(
pool_stride
,
all_encoder_outputs
):
expected_data_shape
[
1
]
=
unpool_length
+
(
expected_data_shape
[
1
]
+
layer_pool_stride
-
1
-
unpool_length
)
//
layer_pool_stride
self
.
assertAllEqual
(
expected_data_shape
,
data
.
shape
.
as_list
())
self
.
assertAllEqual
(
expected_pooled_shape
,
pooled
.
shape
.
as_list
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment