"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "83b1e42f3012da2f3674b118e83d9d33d9aba633"
Commit a0d4ac79 authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Update classifier_trainer to allow metrics disabling for benchmarks.

PiperOrigin-RevId: 313321531
parent 9cd1c1d5
......@@ -62,6 +62,7 @@ def _get_classifier_parameters(
gpu_thread_mode: Optional[str] = None,
dataset_num_private_threads: Optional[int] = None,
loss_scale: Optional[str] = None,
report_metrics: bool = True,
batchnorm_spatial_persistent: bool = False) -> MutableMapping[str, Any]:
"""Gets classifier trainer's ResNet parameters."""
return {
......@@ -97,6 +98,7 @@ def _get_classifier_parameters(
'enable_checkpoint_and_export': False,
'enable_time_history': True,
},
'metrics': ['accuracy'] if report_metrics else [],
},
'model': {
'loss': {
......@@ -169,6 +171,7 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
run_eagerly=run_eagerly,
gpu_thread_mode=gpu_thread_mode,
dataset_num_private_threads=dataset_num_private_threads,
report_metrics=True,
loss_scale=loss_scale,
batchnorm_spatial_persistent=True)
FLAGS.params_override = json.dumps(parameters)
......@@ -353,6 +356,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
gpu_thread_mode=gpu_thread_mode,
dataset_num_private_threads=dataset_num_private_threads,
loss_scale=loss_scale,
report_metrics=False,
batchnorm_spatial_persistent=True)
FLAGS.params_override = json.dumps(parameters)
if distribution_strategy == 'tpu':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment