Unverified Commit da1d3e60 authored by Dong Lin's avatar Dong Lin Committed by GitHub
Browse files

Set data_dir to cifar-10-batches-bin in keras_cifar_benchmark.py (#6251)

parent 21a4ad75
...@@ -31,6 +31,7 @@ MIN_TOP_1_ACCURACY = 0.925 ...@@ -31,6 +31,7 @@ MIN_TOP_1_ACCURACY = 0.925
MAX_TOP_1_ACCURACY = 0.938 MAX_TOP_1_ACCURACY = 0.938
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
CIFAR_DATA_DIR_NAME = 'cifar-10-batches-bin'
class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark): class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark):
...@@ -47,7 +48,7 @@ class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark): ...@@ -47,7 +48,7 @@ class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark):
named arguments before updating the constructor. named arguments before updating the constructor.
""" """
self.data_dir = os.path.join(root_data_dir, 'cifar-10-batches-bin') self.data_dir = os.path.join(root_data_dir, CIFAR_DATA_DIR_NAME)
flag_methods = [ flag_methods = [
keras_common.define_keras_flags, cifar_main.define_cifar_flags keras_common.define_keras_flags, cifar_main.define_cifar_flags
] ]
...@@ -210,25 +211,25 @@ class Resnet56KerasBenchmarkSynth(Resnet56KerasBenchmarkBase): ...@@ -210,25 +211,25 @@ class Resnet56KerasBenchmarkSynth(Resnet56KerasBenchmarkBase):
"""Synthetic benchmarks for ResNet56 and Keras.""" """Synthetic benchmarks for ResNet56 and Keras."""
def __init__(self, output_dir=None, root_data_dir=None, **kwargs): def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
def_flags = {} default_flags = {}
def_flags['skip_eval'] = True default_flags['skip_eval'] = True
def_flags['use_synthetic_data'] = True default_flags['use_synthetic_data'] = True
def_flags['train_steps'] = 110 default_flags['train_steps'] = 110
def_flags['log_steps'] = 10 default_flags['log_steps'] = 10
super(Resnet56KerasBenchmarkSynth, self).__init__( super(Resnet56KerasBenchmarkSynth, self).__init__(
output_dir=output_dir, default_flags=def_flags) output_dir=output_dir, default_flags=default_flags)
class Resnet56KerasBenchmarkReal(Resnet56KerasBenchmarkBase): class Resnet56KerasBenchmarkReal(Resnet56KerasBenchmarkBase):
"""Real data benchmarks for ResNet56 and Keras.""" """Real data benchmarks for ResNet56 and Keras."""
def __init__(self, output_dir=None, root_data_dir=None, **kwargs): def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
def_flags = {} default_flags = {}
def_flags['skip_eval'] = True default_flags['skip_eval'] = True
def_flags['data_dir'] = self.data_dir default_flags['data_dir'] = os.path.join(root_data_dir, CIFAR_DATA_DIR_NAME)
def_flags['train_steps'] = 110 default_flags['train_steps'] = 110
def_flags['log_steps'] = 10 default_flags['log_steps'] = 10
super(Resnet56KerasBenchmarkReal, self).__init__( super(Resnet56KerasBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=def_flags) output_dir=output_dir, default_flags=default_flags)
...@@ -555,7 +555,7 @@ def resnet_main( ...@@ -555,7 +555,7 @@ def resnet_main(
flags_obj.train_epochs) flags_obj.train_epochs)
use_train_and_evaluate = flags_obj.use_train_and_evaluate or ( use_train_and_evaluate = flags_obj.use_train_and_evaluate or (
distribution_strategy.__class__.__name__ == 'CollectiveAllReduceStrategy') distribution_strategy.__class__.__name__ == 'CollectiveAllReduceStrategy')
if use_train_and_evaluate: if use_train_and_evaluate:
train_spec = tf.estimator.TrainSpec( train_spec = tf.estimator.TrainSpec(
input_fn=lambda: input_fn_train(train_epochs), hooks=train_hooks, input_fn=lambda: input_fn_train(train_epochs), hooks=train_hooks,
...@@ -588,6 +588,10 @@ def resnet_main( ...@@ -588,6 +588,10 @@ def resnet_main(
int(n_loops)) int(n_loops))
if num_train_epochs: if num_train_epochs:
# Since we are calling classifier.train immediately in each loop, the
# value of num_train_epochs in the lambda function will not be changed
# before it is used. So it is safe to ignore the pylint error here
# pylint: disable=cell-var-from-loop
classifier.train(input_fn=lambda: input_fn_train(num_train_epochs), classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
hooks=train_hooks, max_steps=flags_obj.max_train_steps) hooks=train_hooks, max_steps=flags_obj.max_train_steps)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment