Unverified Commit cfa37aab authored by Haoyu Zhang's avatar Haoyu Zhang Committed by GitHub
Browse files

Do not report accuracy metrics for benchmark tests (#6757)

* Do not report metrics in performance benchmarks

* Rename flag
parent bae940dc
......@@ -333,6 +333,8 @@ def define_keras_flags():
flags.DEFINE_boolean(name='skip_eval', default=False, help='Skip evaluation?')
flags.DEFINE_boolean(name='use_trivial_model', default=False,
help='Whether to use a trivial Keras model.')
flags.DEFINE_boolean(name='report_accuracy_metrics', default=True,
help='Report metrics during training and evaluation.')
flags.DEFINE_boolean(name='use_tensor_lr', default=False,
help='Use learning rate tensor instead of a callback.')
flags.DEFINE_boolean(
......
......@@ -650,6 +650,7 @@ class Resnet50KerasBenchmarkSynth(Resnet50KerasBenchmarkBase):
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
def_flags = {}
def_flags['skip_eval'] = True
def_flags['report_accuracy_metrics'] = False
def_flags['use_synthetic_data'] = True
def_flags['train_steps'] = 110
def_flags['log_steps'] = 10
......@@ -664,6 +665,7 @@ class Resnet50KerasBenchmarkReal(Resnet50KerasBenchmarkBase):
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
def_flags = {}
def_flags['skip_eval'] = True
def_flags['report_accuracy_metrics'] = False
def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
def_flags['train_steps'] = 110
def_flags['log_steps'] = 10
......@@ -682,6 +684,7 @@ class TrivialKerasBenchmarkReal(keras_benchmark.KerasBenchmark):
]
def_flags = {}
def_flags['skip_eval'] = True
def_flags['report_accuracy_metrics'] = False
def_flags['dtype'] = 'fp16'
def_flags['enable_xla'] = True
def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
......
......@@ -210,7 +210,8 @@ def run(flags_obj):
model.compile(loss='sparse_categorical_crossentropy',
optimizer=optimizer,
metrics=['sparse_categorical_accuracy'],
metrics=(['sparse_categorical_accuracy']
if flags_obj.report_accuracy_metrics else None),
cloning=flags_obj.clone_model_in_keras_dist_strat)
callbacks = keras_common.get_callbacks(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment