Commit ba39a3db authored by Toby Boyd's avatar Toby Boyd
Browse files

Change test to use run() to match API change.

parent 80dcd27c
......@@ -9,6 +9,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import cifar10_main as cifar_main
import official.resnet.keras.keras_cifar_main as keras_cifar_main
import official.resnet.keras.keras_common as keras_common
DATA_DIR = '/data/cifar10_data/'
......@@ -32,7 +33,7 @@ class KerasCifar10BenchmarkTests(object):
flags.FLAGS.model_dir = self._get_model_dir('keras_resnet56_1_gpu')
flags.FLAGS.resnet_size = 56
flags.FLAGS.dtype = 'fp32'
stats = keras_cifar_main.run_cifar_with_keras(flags.FLAGS)
stats = keras_cifar_main.run(flags.FLAGS)
self._fill_report_object(stats)
def keras_resnet56_4_gpu(self):
......@@ -45,7 +46,7 @@ class KerasCifar10BenchmarkTests(object):
flags.FLAGS.model_dir = ''
flags.FLAGS.resnet_size = 56
flags.FLAGS.dtype = 'fp32'
stats = keras_cifar_main.run_cifar_with_keras(flags.FLAGS)
stats = keras_cifar_main.run(flags.FLAGS)
self._fill_report_object(stats)
def keras_resnet56_no_dist_strat_1_gpu(self):
......@@ -60,7 +61,7 @@ class KerasCifar10BenchmarkTests(object):
'keras_resnet56_no_dist_strat_1_gpu')
flags.FLAGS.resnet_size = 56
flags.FLAGS.dtype = 'fp32'
stats = keras_cifar_main.run_cifar_with_keras(flags.FLAGS)
stats = keras_cifar_main.run(flags.FLAGS)
self._fill_report_object(stats)
def _fill_report_object(self, stats):
......@@ -76,17 +77,14 @@ class KerasCifar10BenchmarkTests(object):
return os.path.join(self.output_dir, folder_name)
def _setup(self):
"""Setups up and resets flags before each test."""
tf.logging.set_verbosity(tf.logging.DEBUG)
if KerasCifar10BenchmarkTests.local_flags is None:
print('Build Flags!!!!')
keras_cifar_main.define_keras_cifar_flags()
keras_common.define_keras_flags()
cifar_main.define_cifar_flags()
# Loads flags to get defaults to then override.
flags.FLAGS(['foo'])
saved_flag_values = flagsaver.save_flag_values()
KerasCifar10BenchmarkTests.local_flags = saved_flag_values
return
print('Restore Flags')
flagsaver.restore_flag_values(KerasCifar10BenchmarkTests.local_flags)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment