Unverified Commit cc98c0aa authored by Reed's avatar Reed Committed by GitHub
Browse files

Fix dynamic loss scaling crash in benchmarks (#6532)

parent 3f94db4e
......@@ -46,7 +46,8 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
"""
flag_methods = [
keras_common.define_keras_flags, imagenet_main.define_imagenet_flags
keras_common.define_keras_flags,
lambda: imagenet_main.define_imagenet_flags(dynamic_loss_scale=True)
]
self.data_dir = os.path.join(root_data_dir, 'imagenet')
......@@ -145,7 +146,8 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
def __init__(self, output_dir=None, default_flags=None):
flag_methods = [
keras_common.define_keras_flags, imagenet_main.define_imagenet_flags
keras_common.define_keras_flags,
lambda: imagenet_main.define_imagenet_flags(dynamic_loss_scale=True)
]
super(Resnet50KerasBenchmarkBase, self).__init__(
......@@ -527,7 +529,8 @@ class TrivialKerasBenchmarkReal(keras_benchmark.KerasBenchmark):
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
flag_methods = [
keras_common.define_keras_flags, imagenet_main.define_imagenet_flags
keras_common.define_keras_flags,
lambda: imagenet_main.define_imagenet_flags(dynamic_loss_scale=True)
]
def_flags = {}
def_flags['skip_eval'] = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment