Commit a0d4ac79 authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Update classifier_trainer to allow metrics disabling for benchmarks.

PiperOrigin-RevId: 313321531
parent 9cd1c1d5
...@@ -62,6 +62,7 @@ def _get_classifier_parameters( ...@@ -62,6 +62,7 @@ def _get_classifier_parameters(
gpu_thread_mode: Optional[str] = None, gpu_thread_mode: Optional[str] = None,
dataset_num_private_threads: Optional[int] = None, dataset_num_private_threads: Optional[int] = None,
loss_scale: Optional[str] = None, loss_scale: Optional[str] = None,
report_metrics: bool = True,
batchnorm_spatial_persistent: bool = False) -> MutableMapping[str, Any]: batchnorm_spatial_persistent: bool = False) -> MutableMapping[str, Any]:
"""Gets classifier trainer's ResNet parameters.""" """Gets classifier trainer's ResNet parameters."""
return { return {
...@@ -97,6 +98,7 @@ def _get_classifier_parameters( ...@@ -97,6 +98,7 @@ def _get_classifier_parameters(
'enable_checkpoint_and_export': False, 'enable_checkpoint_and_export': False,
'enable_time_history': True, 'enable_time_history': True,
}, },
'metrics': ['accuracy'] if report_metrics else [],
}, },
'model': { 'model': {
'loss': { 'loss': {
...@@ -169,6 +171,7 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark): ...@@ -169,6 +171,7 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
run_eagerly=run_eagerly, run_eagerly=run_eagerly,
gpu_thread_mode=gpu_thread_mode, gpu_thread_mode=gpu_thread_mode,
dataset_num_private_threads=dataset_num_private_threads, dataset_num_private_threads=dataset_num_private_threads,
report_metrics=True,
loss_scale=loss_scale, loss_scale=loss_scale,
batchnorm_spatial_persistent=True) batchnorm_spatial_persistent=True)
FLAGS.params_override = json.dumps(parameters) FLAGS.params_override = json.dumps(parameters)
...@@ -353,6 +356,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -353,6 +356,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
gpu_thread_mode=gpu_thread_mode, gpu_thread_mode=gpu_thread_mode,
dataset_num_private_threads=dataset_num_private_threads, dataset_num_private_threads=dataset_num_private_threads,
loss_scale=loss_scale, loss_scale=loss_scale,
report_metrics=False,
batchnorm_spatial_persistent=True) batchnorm_spatial_persistent=True)
FLAGS.params_override = json.dumps(parameters) FLAGS.params_override = json.dumps(parameters)
if distribution_strategy == 'tpu': if distribution_strategy == 'tpu':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment