Commit 8e1e0339 authored by Ayush Dubey's avatar Ayush Dubey Committed by A. Unique TensorFlower
Browse files

Add a multi-worker real data Keras ResNet benchmark.

PiperOrigin-RevId: 264511778
parent 8856b918
......@@ -973,7 +973,7 @@ class Resnet50MultiWorkerKerasBenchmark(Resnet50KerasBenchmarkBase):
class Resnet50MultiWorkerKerasBenchmarkSynth(Resnet50MultiWorkerKerasBenchmark):
"""Resnet50 multi-worker synthetic benchmark tests."""
"""Resnet50 multi-worker synthetic data benchmark tests."""
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
def_flags = {}
......@@ -987,5 +987,20 @@ class Resnet50MultiWorkerKerasBenchmarkSynth(Resnet50MultiWorkerKerasBenchmark):
output_dir=output_dir, default_flags=def_flags)
class Resnet50MultiWorkerKerasBenchmarkReal(Resnet50MultiWorkerKerasBenchmark):
"""Resnet50 multi-worker real data benchmark tests."""
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
def_flags = {}
def_flags['skip_eval'] = True
def_flags['report_accuracy_metrics'] = False
def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
def_flags['train_steps'] = 110
def_flags['log_steps'] = 10
super(Resnet50MultiWorkerKerasBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=def_flags)
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment