Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
0aba91fc
Commit
0aba91fc
authored
Oct 04, 2021
by
Yuexin Wu
Committed by
A. Unique TensorFlower
Oct 04, 2021
Browse files
Internal change
PiperOrigin-RevId: 400837715
parent
06b2d7d7
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
17 deletions
+25
-17
official/nlp/modeling/networks/funnel_transformer.py
official/nlp/modeling/networks/funnel_transformer.py
+25
-17
No files found.
official/nlp/modeling/networks/funnel_transformer.py
View file @
0aba91fc
...
...
@@ -16,20 +16,21 @@
# pylint: disable=g-classes-have-attributes
from
typing
import
Union
,
Sequence
from
absl
import
logging
import
numpy
as
np
import
tensorflow
as
tf
from
official.nlp
import
keras_nlp
def
_pool_and_concat
(
data
,
unpool_length
:
int
,
strides
:
Union
[
Sequence
[
int
],
def
_pool_and_concat
(
mask
,
unpool_length
:
int
,
strides
:
Union
[
Sequence
[
int
],
int
],
axes
:
Union
[
Sequence
[
int
],
int
]):
"""Pools the
data
along a given axis with stride.
"""Pools the
mask
along a given axis with stride.
It also skips first unpool_length elements.
Args:
data
: Tensor to be pooled.
mask
: Tensor to be pooled.
unpool_length: Leading elements to be skipped.
strides: Strides for the given axes.
axes: Axes to pool the Tensor.
...
...
@@ -45,18 +46,21 @@ def _pool_and_concat(data, unpool_length: int, strides: Union[Sequence[int],
else
:
if
len
(
strides
)
!=
len
(
axes
):
raise
ValueError
(
'The lengths of strides and axes need to match.'
)
# Bypass no pooling cases.
if
np
.
all
(
np
.
array
(
strides
)
==
1
):
return
mask
for
axis
,
stride
in
zip
(
axes
,
strides
):
# Skips first `unpool_length` tokens.
unpool_tensor_shape
=
[
slice
(
None
)]
*
axis
+
[
slice
(
None
,
unpool_length
)]
unpool_tensor
=
data
[
unpool_tensor_shape
]
unpool_tensor
=
mask
[
unpool_tensor_shape
]
# Pools the second half.
pool_tensor_shape
=
[
slice
(
None
)]
*
axis
+
[
slice
(
unpool_length
,
None
,
stride
)
]
pool_tensor
=
data
[
pool_tensor_shape
]
data
=
tf
.
concat
((
unpool_tensor
,
pool_tensor
),
axis
=
axis
)
return
data
pool_tensor
=
mask
[
pool_tensor_shape
]
mask
=
tf
.
concat
((
unpool_tensor
,
pool_tensor
),
axis
=
axis
)
return
mask
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
...
...
@@ -272,6 +276,10 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
strides
=
self
.
_pool_strides
[
0
],
axes
=
[
1
])
for
i
,
layer
in
enumerate
(
self
.
_transformer_layers
):
# Bypass no pooling cases.
if
self
.
_pool_strides
[
i
]
==
1
:
x
=
layer
([
x
,
x
,
attention_mask
])
else
:
# Pools layer for compressing the query length.
pooled_inputs
=
self
.
_att_input_pool_layers
[
i
](
x
[:,
self
.
_unpool_length
:,
:])
...
...
@@ -286,7 +294,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
attention_mask
=
_pool_and_concat
(
attention_mask
,
unpool_length
=
self
.
_unpool_length
,
strides
=
[
self
.
_pool_strides
[
i
+
1
],
self
.
_pool_strides
[
i
]],
strides
=
[
self
.
_pool_strides
[
i
+
1
],
self
.
_pool_strides
[
i
]],
axes
=
[
1
,
2
])
encoder_outputs
.
append
(
x
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment