Commit 306dafc6 authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Updated the model export script such that it is easier for the external users

to export CenterNet model by passing in flags  without changing the pipeline configuration.

PiperOrigin-RevId: 358432613
parent aeab8791
...@@ -207,7 +207,8 @@ class CenterNetModule(tf.Module): ...@@ -207,7 +207,8 @@ class CenterNetModule(tf.Module):
both object detection and keypoint estimation task. both object detection and keypoint estimation task.
""" """
def __init__(self, pipeline_config, max_detections, include_keypoints): def __init__(self, pipeline_config, max_detections, include_keypoints,
label_map_path=''):
"""Initialization. """Initialization.
Args: Args:
...@@ -215,10 +216,15 @@ class CenterNetModule(tf.Module): ...@@ -215,10 +216,15 @@ class CenterNetModule(tf.Module):
max_detections: Max detections desired from the TFLite model. max_detections: Max detections desired from the TFLite model.
include_keypoints: If set true, the output dictionary will include the include_keypoints: If set true, the output dictionary will include the
keypoint coordinates and keypoint confidence scores. keypoint coordinates and keypoint confidence scores.
label_map_path: Path to the label map which is used by CenterNet keypoint
estimation task. If provided, the label_map_path in the configuration
will be replaced by this one.
""" """
self._max_detections = max_detections self._max_detections = max_detections
self._include_keypoints = include_keypoints self._include_keypoints = include_keypoints
self._process_config(pipeline_config) self._process_config(pipeline_config)
if include_keypoints and label_map_path:
pipeline_config.model.center_net.keypoint_label_map_path = label_map_path
self._pipeline_config = pipeline_config self._pipeline_config = pipeline_config
self._model = model_builder.build( self._model = model_builder.build(
self._pipeline_config.model, is_training=False) self._pipeline_config.model, is_training=False)
...@@ -303,7 +309,7 @@ class CenterNetModule(tf.Module): ...@@ -303,7 +309,7 @@ class CenterNetModule(tf.Module):
def export_tflite_model(pipeline_config, trained_checkpoint_dir, def export_tflite_model(pipeline_config, trained_checkpoint_dir,
output_directory, max_detections, use_regular_nms, output_directory, max_detections, use_regular_nms,
include_keypoints=False): include_keypoints=False, label_map_path=''):
"""Exports inference SavedModel for TFLite conversion. """Exports inference SavedModel for TFLite conversion.
NOTE: Only supports SSD meta-architectures for now, and the output model will NOTE: Only supports SSD meta-architectures for now, and the output model will
...@@ -322,6 +328,9 @@ def export_tflite_model(pipeline_config, trained_checkpoint_dir, ...@@ -322,6 +328,9 @@ def export_tflite_model(pipeline_config, trained_checkpoint_dir,
Note that this argument is only used by the SSD model. Note that this argument is only used by the SSD model.
include_keypoints: Decides whether to also output the keypoint predictions. include_keypoints: Decides whether to also output the keypoint predictions.
Note that this argument is only used by the CenterNet model. Note that this argument is only used by the CenterNet model.
label_map_path: Path to the label map which is used by CenterNet keypoint
estimation task. If provided, the label_map_path in the configuration will
be replaced by this one.
Raises: Raises:
ValueError: if pipeline is invalid. ValueError: if pipeline is invalid.
...@@ -339,7 +348,8 @@ def export_tflite_model(pipeline_config, trained_checkpoint_dir, ...@@ -339,7 +348,8 @@ def export_tflite_model(pipeline_config, trained_checkpoint_dir,
max_detections, use_regular_nms) max_detections, use_regular_nms)
elif pipeline_config.model.WhichOneof('model') == 'center_net': elif pipeline_config.model.WhichOneof('model') == 'center_net':
detection_module = CenterNetModule( detection_module = CenterNetModule(
pipeline_config, max_detections, include_keypoints) pipeline_config, max_detections, include_keypoints,
label_map_path=label_map_path)
ckpt = tf.train.Checkpoint(model=detection_module.get_model()) ckpt = tf.train.Checkpoint(model=detection_module.get_model())
else: else:
raise ValueError('Only ssd or center_net models are supported in tflite. ' raise ValueError('Only ssd or center_net models are supported in tflite. '
......
...@@ -53,7 +53,7 @@ certain fields in the provided pipeline_config_path. These are useful for ...@@ -53,7 +53,7 @@ certain fields in the provided pipeline_config_path. These are useful for
making small changes to the inference graph that differ from the training or making small changes to the inference graph that differ from the training or
eval config. eval config.
Example Usage (in which we change the NMS iou_threshold to be 0.5 and Example Usage 1 (in which we change the NMS iou_threshold to be 0.5 and
NMS score_threshold to be 0.0): NMS score_threshold to be 0.0):
python object_detection/export_tflite_model_tf2.py \ python object_detection/export_tflite_model_tf2.py \
--pipeline_config_path path/to/ssd_model/pipeline.config \ --pipeline_config_path path/to/ssd_model/pipeline.config \
...@@ -71,6 +71,27 @@ python object_detection/export_tflite_model_tf2.py \ ...@@ -71,6 +71,27 @@ python object_detection/export_tflite_model_tf2.py \
} \ } \
} \ } \
" "
Example Usage 2 (export CenterNet model for keypoint estimation task with fixed
shape resizer and customized input resolution):
python object_detection/export_tflite_model_tf2.py \
--pipeline_config_path path/to/ssd_model/pipeline.config \
--trained_checkpoint_dir path/to/ssd_model/checkpoint \
--output_directory path/to/exported_model_directory \
--keypoint_label_map_path path/to/label_map.txt \
--max_detections 10 \
--centernet_include_keypoints true \
--config_override " \
model{ \
center_net { \
image_resizer { \
fixed_shape_resizer { \
height: 320 \
width: 320 \
} \
} \
} \
}" \
""" """
from absl import app from absl import app
from absl import flags from absl import flags
...@@ -107,6 +128,13 @@ flags.DEFINE_bool( ...@@ -107,6 +128,13 @@ flags.DEFINE_bool(
'Whether to export the predicted keypoint tensors. Only CenterNet model' 'Whether to export the predicted keypoint tensors. Only CenterNet model'
' supports this flag.' ' supports this flag.'
) )
flags.DEFINE_string(
'keypoint_label_map_path', None,
'Path of the label map used by CenterNet keypoint estimation task. If'
' provided, the label map path in the pipeline config will be replaced by'
' this one. Note that it is only used when exporting CenterNet model for'
' keypoint estimation task.'
)
def main(argv): def main(argv):
...@@ -119,12 +147,14 @@ def main(argv): ...@@ -119,12 +147,14 @@ def main(argv):
with tf.io.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f: with tf.io.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
text_format.Parse(f.read(), pipeline_config) text_format.Parse(f.read(), pipeline_config)
text_format.Parse(FLAGS.config_override, pipeline_config) override_config = pipeline_pb2.TrainEvalPipelineConfig()
text_format.Parse(FLAGS.config_override, override_config)
pipeline_config.MergeFrom(override_config)
export_tflite_graph_lib_tf2.export_tflite_model( export_tflite_graph_lib_tf2.export_tflite_model(
pipeline_config, FLAGS.trained_checkpoint_dir, FLAGS.output_directory, pipeline_config, FLAGS.trained_checkpoint_dir, FLAGS.output_directory,
FLAGS.max_detections, FLAGS.ssd_use_regular_nms, FLAGS.max_detections, FLAGS.ssd_use_regular_nms,
FLAGS.centernet_include_keypoints) FLAGS.centernet_include_keypoints, FLAGS.keypoint_label_map_path)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment