Unverified Commit 4b8fe704 authored by Taylor Robie's avatar Taylor Robie Committed by GitHub
Browse files

Forbid ResNet v1 from running with fp16 (#4207)

* forbid resnet v1 fp16

* address PR comments
parent 911a0d23
...@@ -164,6 +164,13 @@ class BaseTest(tf.test.TestCase): ...@@ -164,6 +164,13 @@ class BaseTest(tf.test.TestCase):
extra_flags=['-resnet_version', '2'] extra_flags=['-resnet_version', '2']
) )
def test_flag_restriction(self):
with self.assertRaises(SystemExit):
integration.run_synthetic(
main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '1', "-dtype", "fp16"]
)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -303,6 +303,13 @@ class BaseTest(tf.test.TestCase): ...@@ -303,6 +303,13 @@ class BaseTest(tf.test.TestCase):
extra_flags=['-resnet_version', '2', '-resnet_size', '200'] extra_flags=['-resnet_version', '2', '-resnet_size', '200']
) )
def test_flag_restriction(self):
with self.assertRaises(SystemExit):
integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '1', '-dtype', 'fp16']
)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -25,8 +25,9 @@ from __future__ import print_function ...@@ -25,8 +25,9 @@ from __future__ import print_function
import os import os
# pylint: disable=g-bad-import-order
from absl import flags from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf
from official.resnet import resnet_model from official.resnet import resnet_model
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
...@@ -34,6 +35,7 @@ from official.utils.export import export ...@@ -34,6 +35,7 @@ from official.utils.export import export
from official.utils.logs import hooks_helper from official.utils.logs import hooks_helper
from official.utils.logs import logger from official.utils.logs import logger
from official.utils.misc import model_helpers from official.utils.misc import model_helpers
# pylint: enable=g-bad-import-order
################################################################################ ################################################################################
...@@ -462,7 +464,6 @@ def define_resnet_flags(resnet_size_choices=None): ...@@ -462,7 +464,6 @@ def define_resnet_flags(resnet_size_choices=None):
help=flags_core.help_wrap( help=flags_core.help_wrap(
'Version of ResNet. (1 or 2) See README.md for details.')) 'Version of ResNet. (1 or 2) See README.md for details.'))
choice_kwargs = dict( choice_kwargs = dict(
name='resnet_size', short_name='rs', default='50', name='resnet_size', short_name='rs', default='50',
help=flags_core.help_wrap('The size of the ResNet model to use.')) help=flags_core.help_wrap('The size of the ResNet model to use.'))
...@@ -471,3 +472,12 @@ def define_resnet_flags(resnet_size_choices=None): ...@@ -471,3 +472,12 @@ def define_resnet_flags(resnet_size_choices=None):
flags.DEFINE_string(**choice_kwargs) flags.DEFINE_string(**choice_kwargs)
else: else:
flags.DEFINE_enum(enum_values=resnet_size_choices, **choice_kwargs) flags.DEFINE_enum(enum_values=resnet_size_choices, **choice_kwargs)
# The current implementation of ResNet v1 is numerically unstable when run
# with fp16 and will produce NaN errors soon after training begins.
msg = ('ResNet version 1 is not currently supported with fp16. '
'Please use version 2 instead.')
@flags.multi_flags_validator(['dtype', 'resnet_version'], message=msg)
def _forbid_v1_fp16(flag_values): # pylint: disable=unused-variable
return (flags_core.DTYPE_MAP[flag_values['dtype']][0] != tf.float16 or
flag_values['resnet_version'] != '1')
...@@ -82,3 +82,4 @@ help_wrap = _conventions.help_wrap ...@@ -82,3 +82,4 @@ help_wrap = _conventions.help_wrap
get_num_gpus = _base.get_num_gpus get_num_gpus = _base.get_num_gpus
get_tf_dtype = _performance.get_tf_dtype get_tf_dtype = _performance.get_tf_dtype
get_loss_scale = _performance.get_loss_scale get_loss_scale = _performance.get_loss_scale
DTYPE_MAP = _performance.DTYPE_MAP
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment