Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
17e923da
Unverified
Commit
17e923da
authored
Apr 03, 2019
by
Reed
Committed by
GitHub
Apr 03, 2019
Browse files
Add dynamic loss scaling support (#6518)
parent
cc9eef76
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
61 additions
and
22 deletions
+61
-22
official/resnet/imagenet_main.py
official/resnet/imagenet_main.py
+3
-2
official/resnet/keras/keras_imagenet_main.py
official/resnet/keras/keras_imagenet_main.py
+1
-1
official/resnet/resnet_run_loop.py
official/resnet/resnet_run_loop.py
+3
-2
official/utils/flags/_performance.py
official/utils/flags/_performance.py
+42
-15
official/utils/flags/flags_test.py
official/utils/flags/flags_test.py
+10
-2
official/utils/flags/guidelines.md
official/utils/flags/guidelines.md
+2
-0
No files found.
official/resnet/imagenet_main.py
View file @
17e923da
...
...
@@ -343,9 +343,10 @@ def imagenet_model_fn(features, labels, mode, params):
)
def
define_imagenet_flags
():
def
define_imagenet_flags
(
dynamic_loss_scale
=
False
):
resnet_run_loop
.
define_resnet_flags
(
resnet_size_choices
=
[
'18'
,
'34'
,
'50'
,
'101'
,
'152'
,
'200'
])
resnet_size_choices
=
[
'18'
,
'34'
,
'50'
,
'101'
,
'152'
,
'200'
],
dynamic_loss_scale
=
dynamic_loss_scale
)
flags
.
adopt_module_key_flags
(
resnet_run_loop
)
flags_core
.
set_defaults
(
train_epochs
=
90
)
...
...
official/resnet/keras/keras_imagenet_main.py
View file @
17e923da
...
...
@@ -235,6 +235,6 @@ def main(_):
if
__name__
==
'__main__'
:
tf
.
compat
.
v1
.
logging
.
set_verbosity
(
tf
.
compat
.
v1
.
logging
.
INFO
)
imagenet_main
.
define_imagenet_flags
()
imagenet_main
.
define_imagenet_flags
(
dynamic_loss_scale
=
True
)
keras_common
.
define_keras_flags
()
absl_app
.
run
(
main
)
official/resnet/resnet_run_loop.py
View file @
17e923da
...
...
@@ -707,13 +707,14 @@ def resnet_main(
return
stats
def
define_resnet_flags
(
resnet_size_choices
=
None
):
def
define_resnet_flags
(
resnet_size_choices
=
None
,
dynamic_loss_scale
=
False
):
"""Add flags and validators for ResNet."""
flags_core
.
define_base
()
flags_core
.
define_performance
(
num_parallel_calls
=
False
,
tf_gpu_thread_mode
=
True
,
datasets_num_private_threads
=
True
,
datasets_num_parallel_batches
=
True
)
datasets_num_parallel_batches
=
True
,
dynamic_loss_scale
=
dynamic_loss_scale
)
flags_core
.
define_image
()
flags_core
.
define_benchmark
()
flags
.
adopt_module_key_flags
(
flags_core
)
...
...
official/utils/flags/_performance.py
View file @
17e923da
...
...
@@ -38,8 +38,10 @@ def get_tf_dtype(flags_obj):
def
get_loss_scale
(
flags_obj
):
if
flags_obj
.
loss_scale
is
not
None
:
if
flags_obj
.
loss_scale
==
"dynamic"
:
return
flags_obj
.
loss_scale
elif
flags_obj
.
loss_scale
is
not
None
:
return
float
(
flags_obj
.
loss_scale
)
return
DTYPE_MAP
[
flags_obj
.
dtype
][
1
]
...
...
@@ -47,7 +49,8 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
synthetic_data
=
True
,
max_train_steps
=
True
,
dtype
=
True
,
all_reduce_alg
=
True
,
tf_gpu_thread_mode
=
False
,
datasets_num_private_threads
=
False
,
datasets_num_parallel_batches
=
False
):
datasets_num_parallel_batches
=
False
,
dynamic_loss_scale
=
False
):
"""Register flags for specifying performance tuning arguments.
Args:
...
...
@@ -63,6 +66,8 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
datasets_num_private_threads: Number of private threads for datasets.
datasets_num_parallel_batches: Determines how many batches to process in
parallel when using map and batch from tf.data.
dynamic_loss_scale: Allow the "loss_scale" flag to take on the value
"dynamic". Only valid if `dtype` is True.
Returns:
A list of flags for core.py to mark as key flags.
...
...
@@ -117,24 +122,46 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
"Variables may be cast to a higher precision on a "
"case-by-case basis for numerical stability."
))
flags
.
DEFINE_integer
(
name
=
"loss_scale"
,
short_name
=
"ls"
,
default
=
None
,
help
=
help_wrap
(
"The amount to scale the loss by when the model is run. Before "
loss_scale_help_text
=
(
"The amount to scale the loss by when the model is run. {}. Before "
"gradients are computed, the loss is multiplied by the loss scale, "
"making all gradients loss_scale times larger. To adjust for this, "
"gradients are divided by the loss scale before being applied to "
"variables. This is mathematically equivalent to training without "
"a loss scale, but the loss scale helps avoid some intermediate "
"gradients from underflowing to zero. If not provided the default "
"for fp16 is 128 and 1 for all other dtypes."
))
"for fp16 is 128 and 1 for all other dtypes.{}"
)
if
dynamic_loss_scale
:
loss_scale_help_text
=
loss_scale_help_text
.
format
(
"This can be an int/float or the string 'dynamic'"
,
" The string 'dynamic' can be used to dynamically determine the "
"optimal loss scale during training, but currently this "
"significantly slows down performance"
)
loss_scale_validation_msg
=
(
"loss_scale should be a positive int/float "
"or the string 'dynamic'."
)
else
:
loss_scale_help_text
=
loss_scale_help_text
.
format
(
"This must be an int/float"
,
""
)
loss_scale_validation_msg
=
"loss_scale should be a positive int/float."
flags
.
DEFINE_string
(
name
=
"loss_scale"
,
short_name
=
"ls"
,
default
=
None
,
help
=
help_wrap
(
loss_scale_help_text
))
loss_scale_val_msg
=
"loss_scale should be a positive integer."
@
flags
.
validator
(
flag_name
=
"loss_scale"
,
message
=
loss_scale_val_msg
)
@
flags
.
validator
(
flag_name
=
"loss_scale"
,
message
=
loss_scale_validation_msg
)
def
_check_loss_scale
(
loss_scale
):
# pylint: disable=unused-variable
"""Validator to check the loss scale flag is valid"""
if
loss_scale
is
None
:
return
True
# null case is handled in get_loss_scale()
if
loss_scale
==
"dynamic"
and
dynamic_loss_scale
:
return
True
try
:
loss_scale
=
float
(
loss_scale
)
except
ValueError
:
return
False
return
loss_scale
>
0
if
all_reduce_alg
:
...
...
official/utils/flags/flags_test.py
View file @
17e923da
...
...
@@ -23,7 +23,7 @@ from official.utils.flags import core as flags_core # pylint: disable=g-bad-imp
def
define_flags
():
flags_core
.
define_base
(
num_gpu
=
False
)
flags_core
.
define_performance
()
flags_core
.
define_performance
(
dynamic_loss_scale
=
True
)
flags_core
.
define_image
()
flags_core
.
define_benchmark
()
...
...
@@ -89,12 +89,20 @@ class BaseTester(unittest.TestCase):
flags_core
.
parse_flags
(
[
__file__
,
"--dtype"
,
dtype_str
,
"--loss_scale"
,
"5"
])
self
.
assertEqual
(
flags_core
.
get_loss_scale
(
flags
.
FLAGS
),
5
)
flags_core
.
parse_flags
(
[
__file__
,
"--dtype"
,
dtype_str
,
"--loss_scale"
,
"dynamic"
])
self
.
assertEqual
(
flags_core
.
get_loss_scale
(
flags
.
FLAGS
),
"dynamic"
)
with
self
.
assertRaises
(
SystemExit
):
flags_core
.
parse_flags
([
__file__
,
"--dtype"
,
"int8"
])
with
self
.
assertRaises
(
SystemExit
):
flags_core
.
parse_flags
([
__file__
,
"--dtype"
,
"fp16"
,
"--loss_scale"
,
"abc"
])
if
__name__
==
"__main__"
:
unittest
.
main
()
official/utils/flags/guidelines.md
View file @
17e923da
...
...
@@ -47,6 +47,8 @@
def get_loss_scale(flags_obj):
if flags_obj.loss_scale == "dynamic":
return flags_obj.loss_scale
if flags_obj.loss_scale is not None:
return flags_obj.loss_scale
return DTYPE_MAP[flags_obj.dtype][1]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment