Unverified Commit da1d3e60 authored by Dong Lin, committed by GitHub

Set data_dir to cifar-10-batches-bin in keras_cifar_benchmark.py (#6251)

parent 21a4ad75
@@ -31,6 +31,7 @@
 MIN_TOP_1_ACCURACY = 0.925
 MAX_TOP_1_ACCURACY = 0.938
 FLAGS = flags.FLAGS
+CIFAR_DATA_DIR_NAME = 'cifar-10-batches-bin'

 class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark):
@@ -47,7 +48,7 @@ class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark):
           named arguments before updating the constructor.
     """
-    self.data_dir = os.path.join(root_data_dir, 'cifar-10-batches-bin')
+    self.data_dir = os.path.join(root_data_dir, CIFAR_DATA_DIR_NAME)
     flag_methods = [
         keras_common.define_keras_flags, cifar_main.define_cifar_flags
     ]
@@ -210,25 +211,25 @@ class Resnet56KerasBenchmarkSynth(Resnet56KerasBenchmarkBase):
   """Synthetic benchmarks for ResNet56 and Keras."""

   def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
-    def_flags = {}
-    def_flags['skip_eval'] = True
-    def_flags['use_synthetic_data'] = True
-    def_flags['train_steps'] = 110
-    def_flags['log_steps'] = 10
+    default_flags = {}
+    default_flags['skip_eval'] = True
+    default_flags['use_synthetic_data'] = True
+    default_flags['train_steps'] = 110
+    default_flags['log_steps'] = 10

     super(Resnet56KerasBenchmarkSynth, self).__init__(
-        output_dir=output_dir, default_flags=def_flags)
+        output_dir=output_dir, default_flags=default_flags)


 class Resnet56KerasBenchmarkReal(Resnet56KerasBenchmarkBase):
   """Real data benchmarks for ResNet56 and Keras."""

   def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
-    def_flags = {}
-    def_flags['skip_eval'] = True
-    def_flags['data_dir'] = self.data_dir
-    def_flags['train_steps'] = 110
-    def_flags['log_steps'] = 10
+    default_flags = {}
+    default_flags['skip_eval'] = True
+    default_flags['data_dir'] = os.path.join(root_data_dir, CIFAR_DATA_DIR_NAME)
+    default_flags['train_steps'] = 110
+    default_flags['log_steps'] = 10

     super(Resnet56KerasBenchmarkReal, self).__init__(
-        output_dir=output_dir, default_flags=def_flags)
+        output_dir=output_dir, default_flags=default_flags)
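Side note on the hunk above (an inference from the diff, not part of the commit message): the old Resnet56KerasBenchmarkReal.__init__ read self.data_dir, an attribute that nothing shown in this diff had assigned on that class, so the real-data benchmark presumably could not resolve its dataset path. The fix builds the path directly from root_data_dir and the new shared CIFAR_DATA_DIR_NAME constant; the def_flags to default_flags rename is purely cosmetic. A minimal, hypothetical sketch of the before/after difference, with made-up class names:

import os

# Mirrors the constant added at the top of keras_cifar_benchmark.py.
CIFAR_DATA_DIR_NAME = 'cifar-10-batches-bin'


class RealDataBenchmarkBefore:
  """Old pattern: reads self.data_dir, which nothing here has assigned."""

  def __init__(self, root_data_dir=None):
    # Raises AttributeError unless a base class set self.data_dir earlier.
    self.default_flags = {'data_dir': self.data_dir}


class RealDataBenchmarkAfter:
  """New pattern: derives the path directly from root_data_dir."""

  def __init__(self, root_data_dir=None):
    self.default_flags = {
        'data_dir': os.path.join(root_data_dir, CIFAR_DATA_DIR_NAME)
    }


print(RealDataBenchmarkAfter('/data').default_flags['data_dir'])
# -> /data/cifar-10-batches-bin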
@@ -555,7 +555,7 @@ def resnet_main(
       flags_obj.train_epochs)

   use_train_and_evaluate = flags_obj.use_train_and_evaluate or (
-    distribution_strategy.__class__.__name__ == 'CollectiveAllReduceStrategy')
+      distribution_strategy.__class__.__name__ == 'CollectiveAllReduceStrategy')
   if use_train_and_evaluate:
     train_spec = tf.estimator.TrainSpec(
         input_fn=lambda: input_fn_train(train_epochs), hooks=train_hooks,
@@ -588,6 +588,10 @@ def resnet_main(
         int(n_loops))

     if num_train_epochs:
+      # Since we are calling classifier.train immediately in each loop, the
+      # value of num_train_epochs in the lambda function will not be changed
+      # before it is used. So it is safe to ignore the pylint error here
+      # pylint: disable=cell-var-from-loop
       classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                        hooks=train_hooks, max_steps=flags_obj.max_train_steps)
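For readers unfamiliar with the warning being silenced above: pylint's cell-var-from-loop fires because a lambda created inside a loop captures the loop variable by reference, not by value, so a call made after the loop has advanced sees a later value. The added comment argues this is harmless here because classifier.train consumes the lambda immediately, before num_train_epochs changes. The standalone sketch below (not from this repository; the function names are made up) shows both behaviours:

def deferred_calls(values):
  """Collects the lambdas first and calls them only after the loop ends."""
  fns = []
  for v in values:
    fns.append(lambda: v)  # pylint: disable=cell-var-from-loop
  # Every lambda shares the same captured variable, which now holds the
  # final loop value, so all calls return that final value.
  return [fn() for fn in fns]


def immediate_calls(values):
  """Calls each lambda inside the loop, analogous to classifier.train(input_fn=...)."""
  results = []
  for v in values:
    fn = lambda: v  # pylint: disable=cell-var-from-loop
    # The captured variable has not changed yet, so the call sees the
    # current value; this is the situation the new comment describes.
    results.append(fn())
  return results


print(deferred_calls([1, 2, 3]))   # [3, 3, 3]
print(immediate_calls([1, 2, 3]))  # [1, 2, 3]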