Commit ace53b59 authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Updated the export_tflite_graph_tf2.py script to support the CenterNet model

for object detection and keypoint estimation tasks.

PiperOrigin-RevId: 350640225
parent 59c4ccb2
...@@ -22,6 +22,7 @@ import tensorflow.compat.v2 as tf ...@@ -22,6 +22,7 @@ import tensorflow.compat.v2 as tf
from object_detection.builders import model_builder from object_detection.builders import model_builder
from object_detection.builders import post_processing_builder from object_detection.builders import post_processing_builder
from object_detection.core import box_list from object_detection.core import box_list
from object_detection.core import standard_fields as fields
_DEFAULT_NUM_CHANNELS = 3 _DEFAULT_NUM_CHANNELS = 3
_DEFAULT_NUM_COORD_BOX = 4 _DEFAULT_NUM_COORD_BOX = 4
...@@ -198,8 +199,111 @@ class SSDModule(tf.Module): ...@@ -198,8 +199,111 @@ class SSDModule(tf.Module):
anchors)[::-1] anchors)[::-1]
class CenterNetModule(tf.Module):
"""Inference Module for TFLite-friendly CenterNet models.
The exported CenterNet model includes the preprocessing and postprocessing
logics so the caller should pass in the raw image pixel values. It supports
both object detection and keypoint estimation task.
"""
def __init__(self, pipeline_config, max_detections, include_keypoints):
"""Initialization.
Args:
pipeline_config: The original pipeline_pb2.TrainEvalPipelineConfig
max_detections: Max detections desired from the TFLite model.
include_keypoints: If set true, the output dictionary will include the
keypoint coordinates and keypoint confidence scores.
"""
self._max_detections = max_detections
self._include_keypoints = include_keypoints
self._process_config(pipeline_config)
self._pipeline_config = pipeline_config
self._model = model_builder.build(
self._pipeline_config.model, is_training=False)
def get_model(self):
return self._model
def _process_config(self, pipeline_config):
self._num_classes = pipeline_config.model.center_net.num_classes
center_net_config = pipeline_config.model.center_net
image_resizer_config = center_net_config.image_resizer
image_resizer = image_resizer_config.WhichOneof('image_resizer_oneof')
self._num_channels = _DEFAULT_NUM_CHANNELS
if image_resizer == 'fixed_shape_resizer':
self._height = image_resizer_config.fixed_shape_resizer.height
self._width = image_resizer_config.fixed_shape_resizer.width
if image_resizer_config.fixed_shape_resizer.convert_to_grayscale:
self._num_channels = 1
else:
raise ValueError(
'Only fixed_shape_resizer'
'is supported with tflite. Found {}'.format(image_resizer))
center_net_config.object_center_params.max_box_predictions = (
self._max_detections)
if not self._include_keypoints:
del center_net_config.keypoint_estimation_task[:]
def input_shape(self):
"""Returns shape of TFLite model input."""
return [1, self._height, self._width, self._num_channels]
@tf.function
def inference_fn(self, image):
"""Encapsulates CenterNet inference for TFLite conversion.
Args:
image: a float32 tensor of shape [1, image_height, image_width, channel]
denoting the image pixel values.
Returns:
A dictionary of predicted tensors:
classes: a float32 tensor with shape [1, max_detections] denoting class
ID for each detection.
scores: a float32 tensor with shape [1, max_detections] denoting score
for each detection.
boxes: a float32 tensor with shape [1, max_detections, 4] denoting
coordinates of each detected box.
keypoints: a float32 with shape [1, max_detections, num_keypoints, 2]
denoting the predicted keypoint coordinates (normalized in between
0-1). Note that [:, :, :, 0] represents the y coordinates and
[:, :, :, 1] represents the x coordinates.
keypoint_scores: a float32 with shape [1, max_detections, num_keypoints]
denoting keypoint confidence scores.
"""
image = tf.cast(image, tf.float32)
image, shapes = self._model.preprocess(image)
prediction_dict = self._model.predict(image, None)
detections = self._model.postprocess(
prediction_dict, true_image_shapes=shapes)
field_names = fields.DetectionResultFields
classes_field = field_names.detection_classes
classes = tf.cast(detections[classes_field], tf.float32)
num_detections = tf.cast(detections[field_names.num_detections], tf.float32)
if self._include_keypoints:
model_outputs = (detections[field_names.detection_boxes], classes,
detections[field_names.detection_scores], num_detections,
detections[field_names.detection_keypoints],
detections[field_names.detection_keypoint_scores])
else:
model_outputs = (detections[field_names.detection_boxes], classes,
detections[field_names.detection_scores], num_detections)
# tf.function@ seems to reverse order of inputs, so reverse them here.
return model_outputs[::-1]
def export_tflite_model(pipeline_config, trained_checkpoint_dir, def export_tflite_model(pipeline_config, trained_checkpoint_dir,
output_directory, max_detections, use_regular_nms): output_directory, max_detections, use_regular_nms,
include_keypoints=False):
"""Exports inference SavedModel for TFLite conversion. """Exports inference SavedModel for TFLite conversion.
NOTE: Only supports SSD meta-architectures for now, and the output model will NOTE: Only supports SSD meta-architectures for now, and the output model will
...@@ -215,6 +319,9 @@ def export_tflite_model(pipeline_config, trained_checkpoint_dir, ...@@ -215,6 +319,9 @@ def export_tflite_model(pipeline_config, trained_checkpoint_dir,
output_directory: Path to write outputs. output_directory: Path to write outputs.
max_detections: Max detections desired from the TFLite model. max_detections: Max detections desired from the TFLite model.
use_regular_nms: If True, TFLite model uses the (slower) multi-class NMS. use_regular_nms: If True, TFLite model uses the (slower) multi-class NMS.
Note that this argument is only used by the SSD model.
include_keypoints: Decides whether to also output the keypoint predictions.
Note that this argument is only used by the CenterNet model.
Raises: Raises:
ValueError: if pipeline is invalid. ValueError: if pipeline is invalid.
...@@ -223,22 +330,26 @@ def export_tflite_model(pipeline_config, trained_checkpoint_dir, ...@@ -223,22 +330,26 @@ def export_tflite_model(pipeline_config, trained_checkpoint_dir,
# Build the underlying model using pipeline config. # Build the underlying model using pipeline config.
# TODO(b/162842801): Add support for other architectures. # TODO(b/162842801): Add support for other architectures.
if pipeline_config.model.WhichOneof('model') != 'ssd': if pipeline_config.model.WhichOneof('model') == 'ssd':
raise ValueError('Only ssd models are supported in tflite. ' detection_model = model_builder.build(
pipeline_config.model, is_training=False)
ckpt = tf.train.Checkpoint(model=detection_model)
# The module helps build a TF SavedModel appropriate for TFLite conversion.
detection_module = SSDModule(pipeline_config, detection_model,
max_detections, use_regular_nms)
elif pipeline_config.model.WhichOneof('model') == 'center_net':
detection_module = CenterNetModule(
pipeline_config, max_detections, include_keypoints)
ckpt = tf.train.Checkpoint(model=detection_module.get_model())
else:
raise ValueError('Only ssd or center_net models are supported in tflite. '
'Found {} in config'.format( 'Found {} in config'.format(
pipeline_config.model.WhichOneof('model'))) pipeline_config.model.WhichOneof('model')))
detection_model = model_builder.build(
pipeline_config.model, is_training=False)
ckpt = tf.train.Checkpoint(model=detection_model)
manager = tf.train.CheckpointManager( manager = tf.train.CheckpointManager(
ckpt, trained_checkpoint_dir, max_to_keep=1) ckpt, trained_checkpoint_dir, max_to_keep=1)
status = ckpt.restore(manager.latest_checkpoint).expect_partial() status = ckpt.restore(manager.latest_checkpoint).expect_partial()
# The module helps build a TF SavedModel appropriate for TFLite conversion.
detection_module = SSDModule(pipeline_config, detection_model, max_detections,
use_regular_nms)
# Getting the concrete function traces the graph and forces variables to # Getting the concrete function traces the graph and forces variables to
# be constructed; only after this can we save the saved model. # be constructed; only after this can we save the saved model.
status.assert_existing_objects_matched() status.assert_existing_objects_matched()
......
...@@ -27,6 +27,7 @@ from object_detection.builders import model_builder ...@@ -27,6 +27,7 @@ from object_detection.builders import model_builder
from object_detection.core import model from object_detection.core import model
from object_detection.protos import pipeline_pb2 from object_detection.protos import pipeline_pb2
from object_detection.utils import tf_version from object_detection.utils import tf_version
from google.protobuf import text_format
if six.PY2: if six.PY2:
import mock # pylint: disable=g-importing-member,g-import-not-at-top import mock # pylint: disable=g-importing-member,g-import-not-at-top
...@@ -79,6 +80,10 @@ class FakeModel(model.DetectionModel): ...@@ -79,6 +80,10 @@ class FakeModel(model.DetectionModel):
tf.constant([[0, 1], [1, 0]], tf.float32), tf.constant([[0, 1], [1, 0]], tf.float32),
'num_detections': 'num_detections':
tf.constant([2, 1], tf.float32), tf.constant([2, 1], tf.float32),
'detection_keypoints':
tf.zeros([2, 17, 2], tf.float32),
'detection_keypoint_scores':
tf.zeros([2, 17], tf.float32),
} }
return postprocessed_tensors return postprocessed_tensors
...@@ -125,6 +130,49 @@ class ExportTfLiteGraphTest(tf.test.TestCase): ...@@ -125,6 +130,49 @@ class ExportTfLiteGraphTest(tf.test.TestCase):
pipeline_config.model.ssd.post_processing.batch_non_max_suppression.iou_threshold = 0.5 pipeline_config.model.ssd.post_processing.batch_non_max_suppression.iou_threshold = 0.5
return pipeline_config return pipeline_config
def _get_center_net_config(self):
pipeline_config_text = """
model {
center_net {
num_classes: 1
feature_extractor {
type: "mobilenet_v2_fpn"
}
image_resizer {
fixed_shape_resizer {
height: 10
width: 10
}
}
object_detection_task {
localization_loss {
l1_localization_loss {
}
}
}
object_center_params {
classification_loss {
}
max_box_predictions: 20
}
keypoint_estimation_task {
loss {
localization_loss {
l1_localization_loss {
}
}
classification_loss {
penalty_reduced_logistic_focal_loss {
}
}
}
}
}
}
"""
return text_format.Parse(
pipeline_config_text, pipeline_pb2.TrainEvalPipelineConfig())
# The tf.implements signature is important since it ensures MLIR legalization, # The tf.implements signature is important since it ensures MLIR legalization,
# so we test it here. # so we test it here.
def test_postprocess_implements_signature(self): def test_postprocess_implements_signature(self):
...@@ -177,7 +225,7 @@ class ExportTfLiteGraphTest(tf.test.TestCase): ...@@ -177,7 +225,7 @@ class ExportTfLiteGraphTest(tf.test.TestCase):
model_builder, 'build', autospec=True) as mock_builder: model_builder, 'build', autospec=True) as mock_builder:
mock_builder.return_value = FakeModel() mock_builder.return_value = FakeModel()
output_directory = os.path.join(tmp_dir, 'output') output_directory = os.path.join(tmp_dir, 'output')
expected_message = 'Only ssd models are supported in tflite' expected_message = 'Only ssd or center_net models are supported in tflite'
try: try:
export_tflite_graph_lib_tf2.export_tflite_model( export_tflite_graph_lib_tf2.export_tflite_model(
pipeline_config=pipeline_config, pipeline_config=pipeline_config,
...@@ -240,6 +288,55 @@ class ExportTfLiteGraphTest(tf.test.TestCase): ...@@ -240,6 +288,55 @@ class ExportTfLiteGraphTest(tf.test.TestCase):
# should be 4. # should be 4.
self.assertEqual(4, len(detections)) self.assertEqual(4, len(detections))
def test_center_net_inference_object_detection(self):
tmp_dir = self.get_temp_dir()
output_directory = os.path.join(tmp_dir, 'output')
self._save_checkpoint_from_mock_model(tmp_dir)
with mock.patch.object(
model_builder, 'build', autospec=True) as mock_builder:
mock_builder.return_value = FakeModel()
export_tflite_graph_lib_tf2.export_tflite_model(
pipeline_config=self._get_center_net_config(),
trained_checkpoint_dir=tmp_dir,
output_directory=output_directory,
max_detections=10,
use_regular_nms=False)
saved_model_path = os.path.join(output_directory, 'saved_model')
detect_fn = tf.saved_model.load(saved_model_path)
detect_fn_sig = detect_fn.signatures['serving_default']
image = tf.zeros(shape=[1, 10, 10, 3], dtype=tf.float32)
detections = detect_fn_sig(image)
# The exported graph doesn't have numerically correct outputs, but there
# should be 4.
self.assertEqual(4, len(detections))
def test_center_net_inference_keypoint(self):
tmp_dir = self.get_temp_dir()
output_directory = os.path.join(tmp_dir, 'output')
self._save_checkpoint_from_mock_model(tmp_dir)
with mock.patch.object(
model_builder, 'build', autospec=True) as mock_builder:
mock_builder.return_value = FakeModel()
export_tflite_graph_lib_tf2.export_tflite_model(
pipeline_config=self._get_center_net_config(),
trained_checkpoint_dir=tmp_dir,
output_directory=output_directory,
max_detections=10,
use_regular_nms=False,
include_keypoints=True)
saved_model_path = os.path.join(output_directory, 'saved_model')
detect_fn = tf.saved_model.load(saved_model_path)
detect_fn_sig = detect_fn.signatures['serving_default']
image = tf.zeros(shape=[1, 10, 10, 3], dtype=tf.float32)
detections = detect_fn_sig(image)
# The exported graph doesn't have numerically correct outputs, but there
# should be 6 (4 for boxes, 2 for keypoints).
self.assertEqual(6, len(detections))
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -94,13 +94,19 @@ flags.DEFINE_string('output_directory', None, 'Path to write outputs.') ...@@ -94,13 +94,19 @@ flags.DEFINE_string('output_directory', None, 'Path to write outputs.')
flags.DEFINE_string( flags.DEFINE_string(
'config_override', '', 'pipeline_pb2.TrainEvalPipelineConfig ' 'config_override', '', 'pipeline_pb2.TrainEvalPipelineConfig '
'text proto to override pipeline_config_path.') 'text proto to override pipeline_config_path.')
# SSD-specific flags flags.DEFINE_integer('max_detections', 10,
flags.DEFINE_integer('ssd_max_detections', 10,
'Maximum number of detections (boxes) to return.') 'Maximum number of detections (boxes) to return.')
# SSD-specific flags
flags.DEFINE_bool( flags.DEFINE_bool(
'ssd_use_regular_nms', False, 'ssd_use_regular_nms', False,
'Flag to set postprocessing op to use Regular NMS instead of Fast NMS ' 'Flag to set postprocessing op to use Regular NMS instead of Fast NMS '
'(Default false).') '(Default false).')
# CenterNet-specific flags
flags.DEFINE_bool(
'centernet_include_keypoints', False,
'Whether to export the predicted keypoint tensors. Only CenterNet model'
' supports this flag.'
)
def main(argv): def main(argv):
...@@ -115,11 +121,10 @@ def main(argv): ...@@ -115,11 +121,10 @@ def main(argv):
text_format.Parse(f.read(), pipeline_config) text_format.Parse(f.read(), pipeline_config)
text_format.Parse(FLAGS.config_override, pipeline_config) text_format.Parse(FLAGS.config_override, pipeline_config)
export_tflite_graph_lib_tf2.export_tflite_model(pipeline_config, export_tflite_graph_lib_tf2.export_tflite_model(
FLAGS.trained_checkpoint_dir, pipeline_config, FLAGS.trained_checkpoint_dir, FLAGS.output_directory,
FLAGS.output_directory, FLAGS.max_detections, FLAGS.ssd_use_regular_nms,
FLAGS.ssd_max_detections, FLAGS.centernet_include_keypoints)
FLAGS.ssd_use_regular_nms)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment