Commit f28059c0 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Fix for resnet data_dir in the keras_imagenet_benchmark suite.

PiperOrigin-RevId: 270122397
parent 0a79175b
...@@ -46,8 +46,7 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark): ...@@ -46,8 +46,7 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags] flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]
self.data_dir = ("/readahead/200M/placer/prod/home/distbelief/" self.data_dir = os.path.join(root_data_dir, 'imagenet')
"imagenet-tensorflow/imagenet-2012-tfrecord")
super(Resnet50KerasAccuracy, self).__init__( super(Resnet50KerasAccuracy, self).__init__(
output_dir=output_dir, flag_methods=flag_methods) output_dir=output_dir, flag_methods=flag_methods)
...@@ -838,8 +837,7 @@ class Resnet50KerasBenchmarkReal(Resnet50KerasBenchmarkBase): ...@@ -838,8 +837,7 @@ class Resnet50KerasBenchmarkReal(Resnet50KerasBenchmarkBase):
def_flags = {} def_flags = {}
def_flags['skip_eval'] = True def_flags['skip_eval'] = True
def_flags['report_accuracy_metrics'] = False def_flags['report_accuracy_metrics'] = False
def_flags['data_dir'] = ("/readahead/200M/placer/prod/home/distbelief/" def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
"imagenet-tensorflow/imagenet-2012-tfrecord")
def_flags['train_steps'] = 110 def_flags['train_steps'] = 110
def_flags['log_steps'] = 10 def_flags['log_steps'] = 10
...@@ -859,8 +857,7 @@ class TrivialKerasBenchmarkReal(keras_benchmark.KerasBenchmark): ...@@ -859,8 +857,7 @@ class TrivialKerasBenchmarkReal(keras_benchmark.KerasBenchmark):
def_flags['report_accuracy_metrics'] = False def_flags['report_accuracy_metrics'] = False
def_flags['use_tensor_lr'] = True def_flags['use_tensor_lr'] = True
def_flags['dtype'] = 'fp16' def_flags['dtype'] = 'fp16'
def_flags['data_dir'] = ("/readahead/200M/placer/prod/home/distbelief/" def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
"imagenet-tensorflow/imagenet-2012-tfrecord")
def_flags['train_steps'] = 600 def_flags['train_steps'] = 600
def_flags['log_steps'] = 100 def_flags['log_steps'] = 100
def_flags['distribution_strategy'] = 'default' def_flags['distribution_strategy'] = 'default'
...@@ -974,8 +971,7 @@ class Resnet50MultiWorkerKerasAccuracy(keras_benchmark.KerasBenchmark): ...@@ -974,8 +971,7 @@ class Resnet50MultiWorkerKerasAccuracy(keras_benchmark.KerasBenchmark):
def __init__(self, output_dir=None, root_data_dir=None, **kwargs): def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags] flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]
self.data_dir = ("/readahead/200M/placer/prod/home/distbelief/" self.data_dir = os.path.join(root_data_dir, 'imagenet')
"imagenet-tensorflow/imagenet-2012-tfrecord")
super(Resnet50MultiWorkerKerasAccuracy, self).__init__( super(Resnet50MultiWorkerKerasAccuracy, self).__init__(
output_dir=output_dir, flag_methods=flag_methods) output_dir=output_dir, flag_methods=flag_methods)
...@@ -1153,8 +1149,7 @@ class Resnet50MultiWorkerKerasBenchmarkReal(Resnet50MultiWorkerKerasBenchmark): ...@@ -1153,8 +1149,7 @@ class Resnet50MultiWorkerKerasBenchmarkReal(Resnet50MultiWorkerKerasBenchmark):
def_flags = {} def_flags = {}
def_flags['skip_eval'] = True def_flags['skip_eval'] = True
def_flags['report_accuracy_metrics'] = False def_flags['report_accuracy_metrics'] = False
def_flags['data_dir'] = ("/readahead/200M/placer/prod/home/distbelief/" def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
"imagenet-tensorflow/imagenet-2012-tfrecord")
def_flags['train_steps'] = 110 def_flags['train_steps'] = 110
def_flags['log_steps'] = 10 def_flags['log_steps'] = 10
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment