"...git@developer.sourcefind.cn:OpenDAS/llama-factory.git" did not exist on "60eb395cca20da96914d303a2d0c164a2036c1d7"
Unverified Commit cfa37aab authored by Haoyu Zhang's avatar Haoyu Zhang Committed by GitHub
Browse files

Do not report accuracy metrics for benchmark tests (#6757)

* Do not report metrics in performance benchmarks

* Rename flag
parent bae940dc
...@@ -333,6 +333,8 @@ def define_keras_flags(): ...@@ -333,6 +333,8 @@ def define_keras_flags():
flags.DEFINE_boolean(name='skip_eval', default=False, help='Skip evaluation?') flags.DEFINE_boolean(name='skip_eval', default=False, help='Skip evaluation?')
flags.DEFINE_boolean(name='use_trivial_model', default=False, flags.DEFINE_boolean(name='use_trivial_model', default=False,
help='Whether to use a trivial Keras model.') help='Whether to use a trivial Keras model.')
flags.DEFINE_boolean(name='report_accuracy_metrics', default=True,
help='Report metrics during training and evaluation.')
flags.DEFINE_boolean(name='use_tensor_lr', default=False, flags.DEFINE_boolean(name='use_tensor_lr', default=False,
help='Use learning rate tensor instead of a callback.') help='Use learning rate tensor instead of a callback.')
flags.DEFINE_boolean( flags.DEFINE_boolean(
......
...@@ -650,6 +650,7 @@ class Resnet50KerasBenchmarkSynth(Resnet50KerasBenchmarkBase): ...@@ -650,6 +650,7 @@ class Resnet50KerasBenchmarkSynth(Resnet50KerasBenchmarkBase):
def __init__(self, output_dir=None, root_data_dir=None, **kwargs): def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
def_flags = {} def_flags = {}
def_flags['skip_eval'] = True def_flags['skip_eval'] = True
def_flags['report_accuracy_metrics'] = False
def_flags['use_synthetic_data'] = True def_flags['use_synthetic_data'] = True
def_flags['train_steps'] = 110 def_flags['train_steps'] = 110
def_flags['log_steps'] = 10 def_flags['log_steps'] = 10
...@@ -664,6 +665,7 @@ class Resnet50KerasBenchmarkReal(Resnet50KerasBenchmarkBase): ...@@ -664,6 +665,7 @@ class Resnet50KerasBenchmarkReal(Resnet50KerasBenchmarkBase):
def __init__(self, output_dir=None, root_data_dir=None, **kwargs): def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
def_flags = {} def_flags = {}
def_flags['skip_eval'] = True def_flags['skip_eval'] = True
def_flags['report_accuracy_metrics'] = False
def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet') def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
def_flags['train_steps'] = 110 def_flags['train_steps'] = 110
def_flags['log_steps'] = 10 def_flags['log_steps'] = 10
...@@ -682,6 +684,7 @@ class TrivialKerasBenchmarkReal(keras_benchmark.KerasBenchmark): ...@@ -682,6 +684,7 @@ class TrivialKerasBenchmarkReal(keras_benchmark.KerasBenchmark):
] ]
def_flags = {} def_flags = {}
def_flags['skip_eval'] = True def_flags['skip_eval'] = True
def_flags['report_accuracy_metrics'] = False
def_flags['dtype'] = 'fp16' def_flags['dtype'] = 'fp16'
def_flags['enable_xla'] = True def_flags['enable_xla'] = True
def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet') def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
......
...@@ -210,7 +210,8 @@ def run(flags_obj): ...@@ -210,7 +210,8 @@ def run(flags_obj):
model.compile(loss='sparse_categorical_crossentropy', model.compile(loss='sparse_categorical_crossentropy',
optimizer=optimizer, optimizer=optimizer,
metrics=['sparse_categorical_accuracy'], metrics=(['sparse_categorical_accuracy']
if flags_obj.report_accuracy_metrics else None),
cloning=flags_obj.clone_model_in_keras_dist_strat) cloning=flags_obj.clone_model_in_keras_dist_strat)
callbacks = keras_common.get_callbacks( callbacks = keras_common.get_callbacks(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment