Commit f9491103 authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Add test for performance exporter

PiperOrigin-RevId: 364363163
parent 0b7674b9
......@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import tempfile
import unittest
......@@ -26,9 +27,9 @@ import six
import tensorflow.compat.v1 as tf
import tensorflow.compat.v2 as tf2
from object_detection import exporter_lib_v2
from object_detection import inputs
from object_detection import model_lib_v2
from object_detection.builders import model_builder
from object_detection.core import model
from object_detection.protos import train_pb2
from object_detection.utils import config_util
......@@ -145,6 +146,12 @@ class SimpleModel(model.DetectionModel):
return []
def fake_model_builder(*_, **__):
return SimpleModel()
FAKE_BUILDER_MAP = {'build': fake_model_builder}
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class ModelCheckpointTest(tf.test.TestCase):
"""Test for model checkpoint related functionality."""
......@@ -153,10 +160,9 @@ class ModelCheckpointTest(tf.test.TestCase):
"""Test that only the most recent checkpoints are kept."""
strategy = tf2.distribute.OneDeviceStrategy(device='/cpu:0')
with mock.patch.object(
model_builder, 'build', autospec=True) as mock_builder:
with strategy.scope():
mock_builder.return_value = SimpleModel()
with mock.patch.dict(
exporter_lib_v2.INPUT_BUILDER_UTIL_MAP, FAKE_BUILDER_MAP):
model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
new_pipeline_config_path = os.path.join(model_dir, 'new_pipeline.config')
......@@ -226,5 +232,40 @@ class CheckpointV2Test(tf.test.TestCase):
unpad_groundtruth_tensors=True)
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class MetricsExportTest(tf.test.TestCase):
@classmethod
def setUpClass(cls): # pylint:disable=g-missing-super-call
tf.keras.backend.clear_session()
def test_export_metrics_json_serializable(self):
"""Tests that Estimator and input function are constructed correctly."""
strategy = tf2.distribute.OneDeviceStrategy(device='/cpu:0')
def export(data, _):
json.dumps(data)
with mock.patch.dict(
exporter_lib_v2.INPUT_BUILDER_UTIL_MAP, FAKE_BUILDER_MAP):
with strategy.scope():
model_dir = tf.test.get_temp_dir()
new_pipeline_config_path = os.path.join(model_dir,
'new_pipeline.config')
pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
config_util.clear_fine_tune_checkpoint(pipeline_config_path,
new_pipeline_config_path)
train_steps = 2
with strategy.scope():
model_lib_v2.train_loop(
new_pipeline_config_path,
model_dir=model_dir,
train_steps=train_steps,
checkpoint_every_n=100,
performance_summary_exporter=export,
**_get_config_kwarg_overrides())
if __name__ == '__main__':
tf.test.main()
......@@ -686,7 +686,7 @@ def train_loop(
'steps_per_sec': np.mean(steps_per_sec_list),
'steps_per_sec_p50': np.median(steps_per_sec_list),
'steps_per_sec_max': max(steps_per_sec_list),
'last_batch_loss': loss
'last_batch_loss': float(loss)
}
mixed_precision = 'bf16' if kwargs['use_bfloat16'] else 'fp32'
performance_summary_exporter(metrics, mixed_precision)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment