Commit 0a79175b authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Fix for resnet data_dir name in copybara.

PiperOrigin-RevId: 270079001
parent 32dadfc2
......@@ -46,7 +46,8 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]
self.data_dir = os.path.join(root_data_dir, 'imagenet')
self.data_dir = ("/readahead/200M/placer/prod/home/distbelief/"
"imagenet-tensorflow/imagenet-2012-tfrecord")
super(Resnet50KerasAccuracy, self).__init__(
output_dir=output_dir, flag_methods=flag_methods)
......@@ -837,7 +838,8 @@ class Resnet50KerasBenchmarkReal(Resnet50KerasBenchmarkBase):
def_flags = {}
def_flags['skip_eval'] = True
def_flags['report_accuracy_metrics'] = False
def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
def_flags['data_dir'] = ("/readahead/200M/placer/prod/home/distbelief/"
"imagenet-tensorflow/imagenet-2012-tfrecord")
def_flags['train_steps'] = 110
def_flags['log_steps'] = 10
......@@ -857,7 +859,8 @@ class TrivialKerasBenchmarkReal(keras_benchmark.KerasBenchmark):
def_flags['report_accuracy_metrics'] = False
def_flags['use_tensor_lr'] = True
def_flags['dtype'] = 'fp16'
def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
def_flags['data_dir'] = ("/readahead/200M/placer/prod/home/distbelief/"
"imagenet-tensorflow/imagenet-2012-tfrecord")
def_flags['train_steps'] = 600
def_flags['log_steps'] = 100
def_flags['distribution_strategy'] = 'default'
......@@ -971,7 +974,8 @@ class Resnet50MultiWorkerKerasAccuracy(keras_benchmark.KerasBenchmark):
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]
self.data_dir = os.path.join(root_data_dir, 'imagenet')
self.data_dir = ("/readahead/200M/placer/prod/home/distbelief/"
"imagenet-tensorflow/imagenet-2012-tfrecord")
super(Resnet50MultiWorkerKerasAccuracy, self).__init__(
output_dir=output_dir, flag_methods=flag_methods)
......@@ -1149,7 +1153,8 @@ class Resnet50MultiWorkerKerasBenchmarkReal(Resnet50MultiWorkerKerasBenchmark):
def_flags = {}
def_flags['skip_eval'] = True
def_flags['report_accuracy_metrics'] = False
def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
def_flags['data_dir'] = ("/readahead/200M/placer/prod/home/distbelief/"
"imagenet-tensorflow/imagenet-2012-tfrecord")
def_flags['train_steps'] = 110
def_flags['log_steps'] = 10
......
......@@ -123,8 +123,7 @@ class Resnet50CtlAccuracy(CtlBenchmark):
flag_methods = [ctl_common.define_ctl_flags, common.define_keras_flags]
self.data_dir = ('/readahead/200M/placer/prod/home/distbelief/'
'imagenet-tensorflow/imagenet-2012-tfrecord')
self.data_dir = os.path.join(root_data_dir, 'imagenet')
super(Resnet50CtlAccuracy, self).__init__(
output_dir=output_dir, flag_methods=flag_methods)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment