Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
350f4854
Commit
350f4854
authored
Sep 22, 2021
by
Yuexin Wu
Committed by
A. Unique TensorFlower
Sep 22, 2021
Browse files
Internal change
PiperOrigin-RevId: 398286699
parent
9d5a1a76
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
63 additions
and
25 deletions
+63
-25
official/nlp/modeling/networks/funnel_transformer.py
official/nlp/modeling/networks/funnel_transformer.py
+41
-20
official/nlp/modeling/networks/funnel_transformer_test.py
official/nlp/modeling/networks/funnel_transformer_test.py
+22
-5
No files found.
official/nlp/modeling/networks/funnel_transformer.py
View file @
350f4854
...
@@ -14,15 +14,16 @@
...
@@ -14,15 +14,16 @@
"""Funnel Transformer network."""
"""Funnel Transformer network."""
# pylint: disable=g-classes-have-attributes
# pylint: disable=g-classes-have-attributes
from
typing
import
Union
,
Collection
from
typing
import
Union
,
Sequence
from
absl
import
logging
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.nlp
import
keras_nlp
from
official.nlp
import
keras_nlp
def
_pool_and_concat
(
data
,
unpool_length
:
int
,
stride
:
int
,
def
_pool_and_concat
(
data
,
unpool_length
:
int
,
strides
:
Union
[
Sequence
[
int
],
axes
:
Union
[
Collection
[
int
],
int
]):
int
],
axes
:
Union
[
Sequence
[
int
],
int
]):
"""Pools the data along a given axis with stride.
"""Pools the data along a given axis with stride.
It also skips first unpool_length elements.
It also skips first unpool_length elements.
...
@@ -30,7 +31,7 @@ def _pool_and_concat(data, unpool_length: int, stride: int,
...
@@ -30,7 +31,7 @@ def _pool_and_concat(data, unpool_length: int, stride: int,
Args:
Args:
data: Tensor to be pooled.
data: Tensor to be pooled.
unpool_length: Leading elements to be skipped.
unpool_length: Leading elements to be skipped.
stride: Stride for the given ax
i
s.
stride
s
: Stride
s
for the given ax
e
s.
axes: Axes to pool the Tensor.
axes: Axes to pool the Tensor.
Returns:
Returns:
...
@@ -39,8 +40,13 @@ def _pool_and_concat(data, unpool_length: int, stride: int,
...
@@ -39,8 +40,13 @@ def _pool_and_concat(data, unpool_length: int, stride: int,
# Wraps the axes as a list.
# Wraps the axes as a list.
if
isinstance
(
axes
,
int
):
if
isinstance
(
axes
,
int
):
axes
=
[
axes
]
axes
=
[
axes
]
if
isinstance
(
strides
,
int
):
strides
=
[
strides
]
*
len
(
axes
)
else
:
if
len
(
strides
)
!=
len
(
axes
):
raise
ValueError
(
'The lengths of strides and axes need to match.'
)
for
axis
in
axes
:
for
axis
,
stride
in
zip
(
axes
,
strides
)
:
# Skips first `unpool_length` tokens.
# Skips first `unpool_length` tokens.
unpool_tensor_shape
=
[
slice
(
None
)]
*
axis
+
[
slice
(
None
,
unpool_length
)]
unpool_tensor_shape
=
[
slice
(
None
)]
*
axis
+
[
slice
(
None
,
unpool_length
)]
unpool_tensor
=
data
[
unpool_tensor_shape
]
unpool_tensor
=
data
[
unpool_tensor_shape
]
...
@@ -80,7 +86,9 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
...
@@ -80,7 +86,9 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
dropout.
dropout.
attention_dropout: The dropout rate to use for the attention layers within
attention_dropout: The dropout rate to use for the attention layers within
the transformer layers.
the transformer layers.
pool_stride: Pooling stride to compress the sequence length.
pool_stride: An int or a list of ints. Pooling stride(s) to compress the
sequence length. If set to int, each layer will have the same stride size.
If set to list, the number of elements needs to match num_layers.
unpool_length: Leading n tokens to be skipped from pooling.
unpool_length: Leading n tokens to be skipped from pooling.
initializer: The initialzer to use for all weights in this encoder.
initializer: The initialzer to use for all weights in this encoder.
output_range: The sequence output range, [0, output_range), by slicing the
output_range: The sequence output range, [0, output_range), by slicing the
...
@@ -185,12 +193,23 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
...
@@ -185,12 +193,23 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
activation
=
'tanh'
,
activation
=
'tanh'
,
kernel_initializer
=
initializer
,
kernel_initializer
=
initializer
,
name
=
'pooler_transform'
)
name
=
'pooler_transform'
)
self
.
_att_input_pool_layer
=
tf
.
keras
.
layers
.
MaxPooling1D
(
if
isinstance
(
pool_stride
,
int
):
pool_size
=
pool_stride
,
# TODO(b/197133196): Pooling layer can be shared.
strides
=
pool_stride
,
pool_strides
=
[
pool_stride
]
*
num_layers
padding
=
'same'
,
else
:
name
=
'att_input_pool_layer'
)
if
len
(
pool_stride
)
!=
num_layers
:
self
.
_pool_stride
=
pool_stride
raise
ValueError
(
'Lengths of pool_stride and num_layers are not equal.'
)
pool_strides
=
pool_stride
self
.
_att_input_pool_layers
=
[]
for
layer_pool_stride
in
pool_strides
:
att_input_pool_layer
=
tf
.
keras
.
layers
.
MaxPooling1D
(
pool_size
=
layer_pool_stride
,
strides
=
layer_pool_stride
,
padding
=
'same'
,
name
=
'att_input_pool_layer'
)
self
.
_att_input_pool_layers
.
append
(
att_input_pool_layer
)
self
.
_pool_strides
=
pool_strides
# This is a list here.
self
.
_unpool_length
=
unpool_length
self
.
_unpool_length
=
unpool_length
self
.
_config
=
{
self
.
_config
=
{
...
@@ -250,11 +269,12 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
...
@@ -250,11 +269,12 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
attention_mask
=
_pool_and_concat
(
attention_mask
=
_pool_and_concat
(
attention_mask
,
attention_mask
,
unpool_length
=
self
.
_unpool_length
,
unpool_length
=
self
.
_unpool_length
,
stride
=
self
.
_pool_stride
,
stride
s
=
self
.
_pool_stride
s
[
0
]
,
axes
=
[
1
])
axes
=
[
1
])
for
layer
in
self
.
_transformer_layers
:
for
i
,
layer
in
enumerate
(
self
.
_transformer_layers
)
:
# Pools layer for compressing the query length.
# Pools layer for compressing the query length.
pooled_inputs
=
self
.
_att_input_pool_layer
(
x
[:,
self
.
_unpool_length
:,
:])
pooled_inputs
=
self
.
_att_input_pool_layers
[
i
](
x
[:,
self
.
_unpool_length
:,
:])
query_inputs
=
tf
.
concat
(
query_inputs
=
tf
.
concat
(
values
=
(
tf
.
cast
(
values
=
(
tf
.
cast
(
x
[:,
:
self
.
_unpool_length
,
:],
x
[:,
:
self
.
_unpool_length
,
:],
...
@@ -262,11 +282,12 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
...
@@ -262,11 +282,12 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
axis
=
1
)
axis
=
1
)
x
=
layer
([
query_inputs
,
x
,
attention_mask
])
x
=
layer
([
query_inputs
,
x
,
attention_mask
])
# Pools the corresponding attention_mask.
# Pools the corresponding attention_mask.
attention_mask
=
_pool_and_concat
(
if
i
<
len
(
self
.
_transformer_layers
)
-
1
:
attention_mask
,
attention_mask
=
_pool_and_concat
(
unpool_length
=
self
.
_unpool_length
,
attention_mask
,
stride
=
self
.
_pool_stride
,
unpool_length
=
self
.
_unpool_length
,
axes
=
[
1
,
2
])
strides
=
[
self
.
_pool_strides
[
i
+
1
],
self
.
_pool_strides
[
i
]],
axes
=
[
1
,
2
])
encoder_outputs
.
append
(
x
)
encoder_outputs
.
append
(
x
)
last_encoder_output
=
encoder_outputs
[
-
1
]
last_encoder_output
=
encoder_outputs
[
-
1
]
...
...
official/nlp/modeling/networks/funnel_transformer_test.py
View file @
350f4854
...
@@ -80,8 +80,24 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -80,8 +80,24 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
self
.
assertAllEqual
(
tf
.
float32
,
data
.
dtype
)
self
.
assertAllEqual
(
tf
.
float32
,
data
.
dtype
)
self
.
assertAllEqual
(
pooled_dtype
,
pooled
.
dtype
)
self
.
assertAllEqual
(
pooled_dtype
,
pooled
.
dtype
)
def
test_invalid_stride_and_num_layers
(
self
):
hidden_size
=
32
num_layers
=
3
pool_stride
=
[
2
,
2
]
unpool_length
=
1
with
self
.
assertRaisesRegex
(
ValueError
,
"pool_stride and num_layers are not equal"
):
_
=
funnel_transformer
.
FunnelTransformerEncoder
(
vocab_size
=
100
,
hidden_size
=
hidden_size
,
num_attention_heads
=
2
,
num_layers
=
num_layers
,
pool_stride
=
pool_stride
,
unpool_length
=
unpool_length
)
@
parameterized
.
named_parameters
(
@
parameterized
.
named_parameters
(
(
"no_stride_no_unpool"
,
1
,
0
),
(
"no_stride_no_unpool"
,
1
,
0
),
(
"stride_list_with_unpool"
,
[
2
,
3
,
4
],
1
),
(
"large_stride_with_unpool"
,
3
,
1
),
(
"large_stride_with_unpool"
,
3
,
1
),
(
"large_stride_with_large_unpool"
,
5
,
10
),
(
"large_stride_with_large_unpool"
,
5
,
10
),
(
"no_stride_with_unpool"
,
1
,
1
),
(
"no_stride_with_unpool"
,
1
,
1
),
...
@@ -110,11 +126,12 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -110,11 +126,12 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
expected_data_shape
=
[
None
,
sequence_length
,
hidden_size
]
expected_data_shape
=
[
None
,
sequence_length
,
hidden_size
]
expected_pooled_shape
=
[
None
,
hidden_size
]
expected_pooled_shape
=
[
None
,
hidden_size
]
self
.
assertLen
(
all_encoder_outputs
,
num_layers
)
self
.
assertLen
(
all_encoder_outputs
,
num_layers
)
for
data
in
all_encoder_outputs
:
if
isinstance
(
pool_stride
,
int
):
expected_data_shape
[
1
]
=
unpool_length
+
(
expected_data_shape
[
1
]
+
pool_stride
=
[
pool_stride
]
*
num_layers
pool_stride
-
1
-
for
layer_pool_stride
,
data
in
zip
(
pool_stride
,
all_encoder_outputs
):
unpool_length
)
//
pool_stride
expected_data_shape
[
1
]
=
unpool_length
+
(
print
(
"shapes:"
,
expected_data_shape
,
data
.
shape
.
as_list
())
expected_data_shape
[
1
]
+
layer_pool_stride
-
1
-
unpool_length
)
//
layer_pool_stride
self
.
assertAllEqual
(
expected_data_shape
,
data
.
shape
.
as_list
())
self
.
assertAllEqual
(
expected_data_shape
,
data
.
shape
.
as_list
())
self
.
assertAllEqual
(
expected_pooled_shape
,
pooled
.
shape
.
as_list
())
self
.
assertAllEqual
(
expected_pooled_shape
,
pooled
.
shape
.
as_list
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment