Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
96a8d744
Commit
96a8d744
authored
May 03, 2022
by
Yuexin Wu
Committed by
A. Unique TensorFlower
May 03, 2022
Browse files
Internal change
PiperOrigin-RevId: 446242527
parent
c7734283
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
5 deletions
+8
-5
official/nlp/modeling/networks/funnel_transformer.py
official/nlp/modeling/networks/funnel_transformer.py
+8
-5
No files found.
official/nlp/modeling/networks/funnel_transformer.py
View file @
96a8d744
...
@@ -343,9 +343,6 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
...
@@ -343,9 +343,6 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
# TODO(b/203665205): unpool_length should be implemented.
# TODO(b/203665205): unpool_length should be implemented.
if
unpool_length
!=
0
:
if
unpool_length
!=
0
:
raise
ValueError
(
'unpool_length is not supported by truncated_avg now.'
)
raise
ValueError
(
'unpool_length is not supported by truncated_avg now.'
)
# Compute the attention masks and pooling transforms.
self
.
_pooling_transforms
=
_create_truncated_avg_transforms
(
max_sequence_length
,
pool_strides
)
else
:
else
:
raise
ValueError
(
'pool_type not supported.'
)
raise
ValueError
(
'pool_type not supported.'
)
...
@@ -359,6 +356,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
...
@@ -359,6 +356,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
name
=
'att_input_pool_layer'
)
name
=
'att_input_pool_layer'
)
self
.
_att_input_pool_layers
.
append
(
att_input_pool_layer
)
self
.
_att_input_pool_layers
.
append
(
att_input_pool_layer
)
self
.
_max_sequence_length
=
max_sequence_length
self
.
_pool_strides
=
pool_strides
# This is a list here.
self
.
_pool_strides
=
pool_strides
# This is a list here.
self
.
_unpool_length
=
unpool_length
self
.
_unpool_length
=
unpool_length
self
.
_pool_type
=
pool_type
self
.
_pool_type
=
pool_type
...
@@ -489,8 +487,13 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
...
@@ -489,8 +487,13 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
axes
=
[
1
,
2
])
axes
=
[
1
,
2
])
encoder_outputs
.
append
(
x
)
encoder_outputs
.
append
(
x
)
elif
self
.
_pool_type
==
_TRUNCATED_AVG
:
elif
self
.
_pool_type
==
_TRUNCATED_AVG
:
# Compute the attention masks and pooling transforms.
# Note we do not compute this in __init__ due to inference converter issue
# b/215659399.
pooling_transforms
=
_create_truncated_avg_transforms
(
self
.
_max_sequence_length
,
self
.
_pool_strides
)
attention_masks
=
_create_truncated_avg_masks
(
mask
,
self
.
_pool_strides
,
attention_masks
=
_create_truncated_avg_masks
(
mask
,
self
.
_pool_strides
,
self
.
_
pooling_transforms
)
pooling_transforms
)
for
i
,
layer
in
enumerate
(
self
.
_transformer_layers
):
for
i
,
layer
in
enumerate
(
self
.
_transformer_layers
):
attention_mask
=
attention_masks
[
i
]
attention_mask
=
attention_masks
[
i
]
# Bypass no pooling cases.
# Bypass no pooling cases.
...
@@ -501,7 +504,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
...
@@ -501,7 +504,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
'BFD,FT->BTD'
,
'BFD,FT->BTD'
,
tf
.
cast
(
x
[:,
self
.
_unpool_length
:,
:],
_get_policy_dtype
()
tf
.
cast
(
x
[:,
self
.
_unpool_length
:,
:],
_get_policy_dtype
()
),
# extra casting for faster mixed computation.
),
# extra casting for faster mixed computation.
self
.
_
pooling_transforms
[
i
])
pooling_transforms
[
i
])
query_inputs
=
tf
.
concat
(
query_inputs
=
tf
.
concat
(
values
=
(
tf
.
cast
(
values
=
(
tf
.
cast
(
x
[:,
:
self
.
_unpool_length
,
:],
x
[:,
:
self
.
_unpool_length
,
:],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment