Commit 5f4d34fc authored by Dong Lin's avatar Dong Lin Committed by Toby Boyd
Browse files

Add kwargs to make the benchmark class constructor forward compatible. (#6246)

This is needed to avoid breaking benchmark execution if PerfZero provides more
Named arguments before  the benchmark class constructor is updated.
parent 5c6fa148
......@@ -35,12 +35,15 @@ class EstimatorCifar10BenchmarkTests(tf.test.Benchmark):
local_flags = None
def __init__(self, output_dir=None, root_data_dir=None):
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
"""A benchmark class.
Args:
output_dir: directory where to output e.g. log files
root_data_dir: directory under which to look for dataset
**kwargs: arbitrary named arguments. This is needed to make the
constructor forward compatible in case PerfZero provides more
named arguments before updating the constructor.
"""
self.output_dir = output_dir
......
......@@ -36,12 +36,15 @@ FLAGS = flags.FLAGS
class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark):
"""Accuracy tests for ResNet56 Keras CIFAR-10."""
def __init__(self, output_dir=None, root_data_dir=None):
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
"""A benchmark class.
Args:
output_dir: directory where to output e.g. log files
root_data_dir: directory under which to look for dataset
**kwargs: arbitrary named arguments. This is needed to make the
constructor forward compatible in case PerfZero provides more
named arguments before updating the constructor.
"""
self.data_dir = os.path.join(root_data_dir, 'cifar-10-batches-bin')
......@@ -206,7 +209,7 @@ class Resnet56KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
class Resnet56KerasBenchmarkSynth(Resnet56KerasBenchmarkBase):
"""Synthetic benchmarks for ResNet56 and Keras."""
def __init__(self, output_dir=None, root_data_dir=None):
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
def_flags = {}
def_flags['skip_eval'] = True
def_flags['use_synthetic_data'] = True
......@@ -220,7 +223,7 @@ class Resnet56KerasBenchmarkSynth(Resnet56KerasBenchmarkBase):
class Resnet56KerasBenchmarkReal(Resnet56KerasBenchmarkBase):
"""Real data benchmarks for ResNet56 and Keras."""
def __init__(self, output_dir=None, root_data_dir=None):
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
def_flags = {}
def_flags['skip_eval'] = True
def_flags['data_dir'] = self.data_dir
......
......@@ -34,12 +34,15 @@ FLAGS = flags.FLAGS
class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
"""Benchmark accuracy tests for ResNet50 in Keras."""
def __init__(self, output_dir=None, root_data_dir=None):
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
"""A benchmark class.
Args:
output_dir: directory where to output e.g. log files
root_data_dir: directory under which to look for dataset
**kwargs: arbitrary named arguments. This is needed to make the
constructor forward compatible in case PerfZero provides more
named arguments before updating the constructor.
"""
flag_methods = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment