Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
dba24007
Unverified
Commit
dba24007
authored
Mar 19, 2019
by
Haoyu Zhang
Committed by
GitHub
Mar 19, 2019
Browse files
Add config to enable XLA in TF 2.0 (#6406)
parent
04792078
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
24 additions
and
8 deletions
+24
-8
official/resnet/keras/keras_cifar_main.py
official/resnet/keras/keras_cifar_main.py
+4
-3
official/resnet/keras/keras_common.py
official/resnet/keras/keras_common.py
+16
-2
official/resnet/keras/keras_imagenet_main.py
official/resnet/keras/keras_imagenet_main.py
+4
-3
No files found.
official/resnet/keras/keras_cifar_main.py
View file @
dba24007
...
...
@@ -98,16 +98,17 @@ def run(flags_obj):
Returns:
Dictionary of training and eval stats.
"""
config
=
keras_common
.
get_config_proto
()
# TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
# Eager is default in tf 2.0 and should not be toggled
if
not
keras_common
.
is_v2_0
():
if
keras_common
.
is_v2_0
():
keras_common
.
set_config_v2
()
else
:
config
=
keras_common
.
get_config_proto_v1
()
if
flags_obj
.
enable_eager
:
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
else
:
sess
=
tf
.
Session
(
config
=
config
)
tf
.
keras
.
backend
.
set_session
(
sess
)
# TODO(haoyuzhang): Set config properly in TF2.0 when the config API is ready.
dtype
=
flags_core
.
get_tf_dtype
(
flags_obj
)
if
dtype
==
'fp16'
:
...
...
official/resnet/keras/keras_common.py
View file @
dba24007
...
...
@@ -129,14 +129,14 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
'change learning rate to %s.'
,
self
.
epochs
,
batch
,
lr
)
def
get_config_proto
():
def
get_config_proto
_v1
():
"""Return config proto according to flag settings, or None to use default."""
config
=
None
if
FLAGS
.
enable_xla
:
# TODO(haoyuzhang): Remove this monkey patch when XLA OOM issue is fixed.
_monkey_patch_org_assert_broadcastable
()
config
=
tf
.
ConfigProto
()
config
=
tf
.
compat
.
v1
.
ConfigProto
()
config
.
graph_options
.
optimizer_options
.
global_jit_level
=
(
tf
.
OptimizerOptions
.
ON_2
)
# Disable PinToHostOptimizer in grappler when enabling XLA because it causes
...
...
@@ -146,6 +146,20 @@ def get_config_proto():
return
config
def
set_config_v2
():
"""Config eager context according to flag values using TF 2.0 API."""
if
FLAGS
.
enable_xla
:
# TODO(haoyuzhang): Remove this monkey patch when XLA OOM issue is fixed.
_monkey_patch_org_assert_broadcastable
()
tf
.
config
.
optimizer
.
set_jit
(
True
)
# Disable PinToHostOptimizer in grappler when enabling XLA because it
# causes OOM and performance regression.
tf
.
config
.
optimizer
.
set_experimental_options
(
{
"pin_to_host_optimization"
:
False
}
)
def
get_optimizer
():
"""Returns optimizer to use."""
# The learning_rate is overwritten at the beginning of each step by callback.
...
...
official/resnet/keras/keras_imagenet_main.py
View file @
dba24007
...
...
@@ -90,16 +90,17 @@ def run(flags_obj):
Returns:
Dictionary of training and eval stats.
"""
config
=
keras_common
.
get_config_proto
()
# TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
# Eager is default in tf 2.0 and should not be toggled
if
not
keras_common
.
is_v2_0
():
if
keras_common
.
is_v2_0
():
keras_common
.
set_config_v2
()
else
:
config
=
keras_common
.
get_config_proto_v1
()
if
flags_obj
.
enable_eager
:
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
else
:
sess
=
tf
.
Session
(
config
=
config
)
tf
.
keras
.
backend
.
set_session
(
sess
)
# TODO(haoyuzhang): Set config properly in TF2.0 when the config API is ready.
dtype
=
flags_core
.
get_tf_dtype
(
flags_obj
)
if
dtype
==
'float16'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment