Commit 31e86e86 authored by Pankaj Kanwar's avatar Pankaj Kanwar Committed by TF Object Detection Team
Browse files

Export metrics to MLCompass

PiperOrigin-RevId: 343920343
parent 2abc27d0
......@@ -21,6 +21,7 @@ from __future__ import print_function
import copy
import os
import time
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow.compat.v2 as tf2
......@@ -419,6 +420,7 @@ def train_loop(
checkpoint_every_n=1000,
checkpoint_max_to_keep=7,
record_summaries=True,
performance_summary_exporter=None,
**kwargs):
"""Trains a model using eager + functions.
......@@ -449,6 +451,7 @@ def train_loop(
checkpoint_max_to_keep:
int, the number of most recent checkpoints to keep in the model directory.
record_summaries: Boolean, whether or not to record summaries.
performance_summary_exporter: function for exporting performance metrics.
**kwargs: Additional keyword arguments for configuration override.
"""
## Parse the configs
......@@ -458,6 +461,7 @@ def train_loop(
'merge_external_params_with_configs']
create_pipeline_proto_from_configs = MODEL_BUILD_UTIL_MAP[
'create_pipeline_proto_from_configs']
steps_per_sec_list = []
configs = get_configs_from_pipeline_file(
pipeline_config_path, config_override=config_override)
......@@ -633,10 +637,12 @@ def train_loop(
time_taken = time.time() - last_step_time
last_step_time = time.time()
steps_per_sec = num_steps_per_iteration * 1.0 / time_taken
tf.compat.v2.summary.scalar(
'steps_per_sec', num_steps_per_iteration * 1.0 / time_taken,
step=global_step)
'steps_per_sec', steps_per_sec, step=global_step)
steps_per_sec_list.append(steps_per_sec)
if global_step.value() - logged_step >= 100:
tf.logging.info(
......@@ -655,6 +661,15 @@ def train_loop(
# training.
clean_temporary_directories(strategy, manager_dir)
clean_temporary_directories(strategy, summary_writer_filepath)
# TODO(pkanwar): add accuracy metrics.
if performance_summary_exporter is not None:
metrics = {
'steps_per_sec': np.mean(steps_per_sec_list),
'steps_per_sec_p50': np.median(steps_per_sec_list),
'steps_per_sec_max': max(steps_per_sec_list),
}
mixed_precision = 'bf16' if kwargs['use_bfloat16'] else 'fp32'
performance_summary_exporter(metrics, mixed_precision)
def prepare_eval_dict(detections, groundtruth, features):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment