Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
a2a1b66f
Commit
a2a1b66f
authored
Nov 20, 2019
by
A. Unique TensorFlower
Browse files
Move distribution stragegy init to beginning of run
PiperOrigin-RevId: 281641189
parent
92384c60
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
16 deletions
+16
-16
official/transformer/v2/transformer_main.py
official/transformer/v2/transformer_main.py
+16
-16
No files found.
official/transformer/v2/transformer_main.py
View file @
a2a1b66f
...
...
@@ -160,22 +160,6 @@ class TransformerTask(object):
params
[
"dtype"
]
=
flags_core
.
get_tf_dtype
(
flags_obj
)
params
[
"enable_metrics_in_training"
]
=
flags_obj
.
enable_metrics_in_training
if
params
[
"dtype"
]
==
tf
.
float16
:
# TODO(reedwm): It's pretty ugly to set the global policy in a constructor
# like this. What if multiple instances of TransformerTask are created?
# We should have a better way in the tf.keras.mixed_precision API of doing
# this.
loss_scale
=
flags_core
.
get_loss_scale
(
flags_obj
,
default_for_fp16
=
"dynamic"
)
policy
=
tf
.
compat
.
v2
.
keras
.
mixed_precision
.
experimental
.
Policy
(
"mixed_float16"
,
loss_scale
=
loss_scale
)
tf
.
compat
.
v2
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
policy
)
if
params
[
"dtype"
]
==
tf
.
bfloat16
:
policy
=
tf
.
compat
.
v2
.
keras
.
mixed_precision
.
experimental
.
Policy
(
"mixed_bfloat16"
)
tf
.
compat
.
v2
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
policy
)
self
.
distribution_strategy
=
distribution_utils
.
get_distribution_strategy
(
distribution_strategy
=
flags_obj
.
distribution_strategy
,
num_gpus
=
num_gpus
,
...
...
@@ -193,6 +177,22 @@ class TransformerTask(object):
else
:
logging
.
info
(
"Not using any distribution strategy."
)
if
params
[
"dtype"
]
==
tf
.
float16
:
# TODO(reedwm): It's pretty ugly to set the global policy in a constructor
# like this. What if multiple instances of TransformerTask are created?
# We should have a better way in the tf.keras.mixed_precision API of doing
# this.
loss_scale
=
flags_core
.
get_loss_scale
(
flags_obj
,
default_for_fp16
=
"dynamic"
)
policy
=
tf
.
compat
.
v2
.
keras
.
mixed_precision
.
experimental
.
Policy
(
"mixed_float16"
,
loss_scale
=
loss_scale
)
tf
.
compat
.
v2
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
policy
)
if
params
[
"dtype"
]
==
tf
.
bfloat16
:
policy
=
tf
.
compat
.
v2
.
keras
.
mixed_precision
.
experimental
.
Policy
(
"mixed_bfloat16"
)
tf
.
compat
.
v2
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
policy
)
@
property
def
use_tpu
(
self
):
if
self
.
distribution_strategy
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment