Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
5292d16e
Commit
5292d16e
authored
Nov 02, 2021
by
Yuexin Wu
Committed by
A. Unique TensorFlower
Nov 02, 2021
Browse files
#Funnel fix `mixed_precision=None` error when mixed with TF1 code.
PiperOrigin-RevId: 407253800
parent
e387ed65
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
8 deletions
+11
-8
official/nlp/modeling/networks/funnel_transformer.py
official/nlp/modeling/networks/funnel_transformer.py
+11
-8
No files found.
official/nlp/modeling/networks/funnel_transformer.py
View file @
5292d16e
...
@@ -27,6 +27,13 @@ _AVG = 'avg'
...
@@ -27,6 +27,13 @@ _AVG = 'avg'
_TRUNCATED_AVG
=
'truncated_avg'
_TRUNCATED_AVG
=
'truncated_avg'
def
_get_policy_dtype
():
try
:
return
tf
.
keras
.
mixed_precision
.
global_policy
().
compute_dtype
or
tf
.
float32
except
AttributeError
:
# tf1 has no attribute 'global_policy'
return
tf
.
float32
def
_pool_and_concat
(
mask
,
unpool_length
:
int
,
strides
:
Union
[
Sequence
[
int
],
def
_pool_and_concat
(
mask
,
unpool_length
:
int
,
strides
:
Union
[
Sequence
[
int
],
int
],
int
],
axes
:
Union
[
Sequence
[
int
],
int
]):
axes
:
Union
[
Sequence
[
int
],
int
]):
...
@@ -105,9 +112,7 @@ def _create_truncated_avg_transforms(seq_length: int,
...
@@ -105,9 +112,7 @@ def _create_truncated_avg_transforms(seq_length: int,
transform
=
[[
1.0
if
(
i
//
pfac
)
==
j
else
0.0
transform
=
[[
1.0
if
(
i
//
pfac
)
==
j
else
0.0
for
j
in
range
(
psl
)]
for
j
in
range
(
psl
)]
for
i
in
range
(
sl
)]
for
i
in
range
(
sl
)]
transform
=
tf
.
constant
(
transform
=
tf
.
constant
(
transform
,
dtype
=
_get_policy_dtype
())
transform
,
dtype
=
tf
.
keras
.
mixed_precision
.
global_policy
().
compute_dtype
)
pooling_transforms
.
append
(
transform
/
pool_stride
)
pooling_transforms
.
append
(
transform
/
pool_stride
)
seq_length
=
pooled_seq_length
seq_length
=
pooled_seq_length
...
@@ -125,7 +130,7 @@ def _create_truncated_avg_masks(input_mask: tf.Tensor,
...
@@ -125,7 +130,7 @@ def _create_truncated_avg_masks(input_mask: tf.Tensor,
Args:
Args:
input_mask: Tensor of shape [batch_size, seq_length].
input_mask: Tensor of shape [batch_size, seq_length].
pool_strides: Sequence of pooling strides for each layer.
pool_strides: Sequence of pooling strides for each layer.
transforms: Sequnce of off-diagonal matrices filling with 0.0 and
transforms: Sequ
e
nce of off-diagonal matrices filling with 0.0 and
1/pool_stride.
1/pool_stride.
Returns:
Returns:
...
@@ -138,8 +143,7 @@ def _create_truncated_avg_masks(input_mask: tf.Tensor,
...
@@ -138,8 +143,7 @@ def _create_truncated_avg_masks(input_mask: tf.Tensor,
attention_masks
=
[]
attention_masks
=
[]
seq_length
=
tf
.
shape
(
input_mask
)[
-
1
]
seq_length
=
tf
.
shape
(
input_mask
)[
-
1
]
layer_mask
=
tf
.
cast
(
layer_mask
=
tf
.
cast
(
input_mask
,
dtype
=
_get_policy_dtype
())
input_mask
,
dtype
=
tf
.
keras
.
mixed_precision
.
global_policy
().
compute_dtype
)
for
pool_stride
,
transform
in
zip
(
pool_strides
,
transforms
):
for
pool_stride
,
transform
in
zip
(
pool_strides
,
transforms
):
if
pool_stride
==
1
:
if
pool_stride
==
1
:
attention_masks
.
append
(
create_2d_mask
(
seq_length
,
layer_mask
))
attention_masks
.
append
(
create_2d_mask
(
seq_length
,
layer_mask
))
...
@@ -423,8 +427,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
...
@@ -423,8 +427,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
else
:
else
:
pooled_inputs
=
tf
.
einsum
(
pooled_inputs
=
tf
.
einsum
(
'BFD,FT->BTD'
,
'BFD,FT->BTD'
,
tf
.
cast
(
x
[:,
self
.
_unpool_length
:,
:],
tf
.
cast
(
x
[:,
self
.
_unpool_length
:,
:],
_get_policy_dtype
()
tf
.
keras
.
mixed_precision
.
global_policy
().
compute_dtype
),
# extra casting for faster mixed computation.
),
# extra casting for faster mixed computation.
self
.
_pooling_transforms
[
i
])
self
.
_pooling_transforms
[
i
])
query_inputs
=
tf
.
concat
(
query_inputs
=
tf
.
concat
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment