Commit b578aee9
Authored Jun 20, 2019 by Reed; committed by Toby Boyd, Jun 20, 2019

Fix Transformer Perfzero issue with fp16 (#7074)

Parent: adc8f11b
Showing 2 changed files with 9 additions and 11 deletions:

official/transformer/v2/transformer_main.py       +9  -5
official/transformer/v2/transformer_main_test.py  +0  -6
official/transformer/v2/transformer_main.py  (view file @ b578aee9)

@@ -120,6 +120,15 @@ class TransformerTask(object):
     params["repeat_dataset"] = None
     params["dtype"] = flags_core.get_tf_dtype(flags_obj)

+    if params["dtype"] == tf.float16:
+      # TODO(reedwm): It's pretty ugly to set the global policy in a constructor
+      # like this. What if multiple instances of TransformerTask are created?
+      # We should have a better way in the tf.keras.mixed_precision API of doing
+      # this.
+      policy = tf.keras.mixed_precision.experimental.Policy(
+          'infer_float32_vars')
+      tf.keras.mixed_precision.experimental.set_policy(policy)
+
   def train(self):
     """Trains the model."""
     params, flags_obj, is_train = self.params, self.flags_obj, True

@@ -263,11 +272,6 @@ def _ensure_dir(log_dir):
 def main(_):
   flags_obj = flags.FLAGS
   with logger.benchmark_context(flags_obj):
-    if flags_core.get_tf_dtype(flags_obj) == 'float16':
-      policy = tf.keras.mixed_precision.experimental.Policy(
-          'infer_float32_vars')
-      tf.keras.mixed_precision.experimental.set_policy(policy)
-
     task = TransformerTask(flags_obj)
     if flags_obj.mode == "train":
       task.train()
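For context, a minimal sketch of the (since-superseded) tf.keras.mixed_precision.experimental API that this commit moves into the constructor. The Dense layer and shapes below are illustrative, not from the diff; under the 'infer_float32_vars' policy of that era, layers computed in the input dtype while creating float32 variables:

```python
import tensorflow as tf

# Set the global mixed-precision policy, as the diff above does.
# 'infer_float32_vars': compute in the input dtype, keep variables in float32.
policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
tf.keras.mixed_precision.experimental.set_policy(policy)

# Illustrative layer (not from the diff), built after set_policy so it
# picks up the global policy.
layer = tf.keras.layers.Dense(8)
out = layer(tf.zeros([1, 4], dtype=tf.float16))
print(layer.kernel.dtype)  # expected: float32 (variables stay float32)
print(out.dtype)           # expected: float16 (compute follows input dtype)
```

Note also the likely root cause visible in the second hunk: main() compared flags_core.get_tf_dtype(flags_obj) against the string 'float16', whereas the new constructor code compares against tf.float16.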
official/transformer/v2/transformer_main_test.py  (view file @ b578aee9)

@@ -92,9 +92,6 @@ class TransformerTaskTest(tf.test.TestCase):
     FLAGS.num_gpus = 2
     FLAGS.param_set = "base"
     FLAGS.dtype = "fp16"
-    policy = tf.keras.mixed_precision.experimental.Policy(
-        'infer_float32_vars')
-    tf.keras.mixed_precision.experimental.set_policy(policy)
     t = tm.TransformerTask(FLAGS)
     t.train()

@@ -131,9 +128,6 @@ class TransformerTaskTest(tf.test.TestCase):
   def test_predict_fp16(self):
     self._prepare_files_and_flags("--dtype=fp16")
-    policy = tf.keras.mixed_precision.experimental.Policy(
-        'infer_float32_vars')
-    tf.keras.mixed_precision.experimental.set_policy(policy)
     t = tm.TransformerTask(FLAGS)
     t.predict()
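Why the tests can drop these lines: after this commit, TransformerTask.__init__ sets the fp16 policy itself whenever the resolved dtype is tf.float16, so any caller that constructs the task directly, including a PerfZero-style harness that never goes through main(), gets mixed precision automatically. A hedged sketch of that direct-construction path (flag setup abbreviated; not the actual PerfZero code):

```python
from absl import flags
from official.transformer.v2 import transformer_main as tm

FLAGS = flags.FLAGS
FLAGS.dtype = "fp16"  # resolved to tf.float16 by flags_core.get_tf_dtype

# No explicit Policy/set_policy calls are needed here any more: the
# constructor applies the fp16 policy itself, which is the fix for the
# PerfZero path.
task = tm.TransformerTask(FLAGS)
task.train()
```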