Commit af1a6c57 authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Add CenterNet option to expose predict() outputs in postprocess.

PiperOrigin-RevId: 437824318
parent f8f4ab71
...@@ -1171,7 +1171,8 @@ def _build_center_net_model(center_net_config, is_training, add_summaries): ...@@ -1171,7 +1171,8 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
temporal_offset_params=temporal_offset_params, temporal_offset_params=temporal_offset_params,
use_depthwise=center_net_config.use_depthwise, use_depthwise=center_net_config.use_depthwise,
compute_heatmap_sparse=center_net_config.compute_heatmap_sparse, compute_heatmap_sparse=center_net_config.compute_heatmap_sparse,
non_max_suppression_fn=non_max_suppression_fn) non_max_suppression_fn=non_max_suppression_fn,
output_prediction_dict=center_net_config.output_prediction_dict)
def _build_center_net_feature_extractor(feature_extractor_config, is_training): def _build_center_net_feature_extractor(feature_extractor_config, is_training):
......
...@@ -2676,7 +2676,8 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2676,7 +2676,8 @@ class CenterNetMetaArch(model.DetectionModel):
use_depthwise=False, use_depthwise=False,
compute_heatmap_sparse=False, compute_heatmap_sparse=False,
non_max_suppression_fn=None, non_max_suppression_fn=None,
unit_height_conv=False): unit_height_conv=False,
output_prediction_dict=False):
"""Initializes a CenterNet model. """Initializes a CenterNet model.
Args: Args:
...@@ -2722,6 +2723,8 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2722,6 +2723,8 @@ class CenterNetMetaArch(model.DetectionModel):
non_max_suppression_fn: Optional Non Max Suppression function to apply. non_max_suppression_fn: Optional Non Max Suppression function to apply.
unit_height_conv: If True, Conv2Ds in prediction heads have asymmetric unit_height_conv: If True, Conv2Ds in prediction heads have asymmetric
kernels with height=1. kernels with height=1.
output_prediction_dict: If true, combines all items from the dictionary
returned by predict() function into the output of postprocess().
""" """
assert object_detection_params or keypoint_params_dict assert object_detection_params or keypoint_params_dict
# Shorten the name for convenience and better formatting. # Shorten the name for convenience and better formatting.
...@@ -2747,6 +2750,7 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2747,6 +2750,7 @@ class CenterNetMetaArch(model.DetectionModel):
self._use_depthwise = use_depthwise self._use_depthwise = use_depthwise
self._compute_heatmap_sparse = compute_heatmap_sparse self._compute_heatmap_sparse = compute_heatmap_sparse
self._output_prediction_dict = output_prediction_dict
# subclasses may not implement the unit_height_conv arg, so only provide it # subclasses may not implement the unit_height_conv arg, so only provide it
# as a kwarg if it is True. # as a kwarg if it is True.
...@@ -4110,6 +4114,10 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -4110,6 +4114,10 @@ class CenterNetMetaArch(model.DetectionModel):
fields.DetectionResultFields.num_detections: num_detections, fields.DetectionResultFields.num_detections: num_detections,
} }
if self._output_prediction_dict:
postprocess_dict.update(prediction_dict)
postprocess_dict['true_image_shapes'] = true_image_shapes
boxes_strided = None boxes_strided = None
if self._od_params: if self._od_params:
boxes_strided = ( boxes_strided = (
...@@ -4122,7 +4130,7 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -4122,7 +4130,7 @@ class CenterNetMetaArch(model.DetectionModel):
postprocess_dict.update({ postprocess_dict.update({
fields.DetectionResultFields.detection_boxes: boxes, fields.DetectionResultFields.detection_boxes: boxes,
'detection_boxes_strided': boxes_strided 'detection_boxes_strided': boxes_strided,
}) })
if self._kp_params_dict: if self._kp_params_dict:
......
...@@ -11,7 +11,7 @@ import "object_detection/protos/preprocessor.proto"; ...@@ -11,7 +11,7 @@ import "object_detection/protos/preprocessor.proto";
// Points" paper [1] // Points" paper [1]
// [1]: https://arxiv.org/abs/1904.07850 // [1]: https://arxiv.org/abs/1904.07850
// Next Id = 16 // Next Id = 26
message CenterNet { message CenterNet {
// Number of classes to predict. // Number of classes to predict.
optional int32 num_classes = 1; optional int32 num_classes = 1;
...@@ -504,6 +504,7 @@ message CenterNet { ...@@ -504,6 +504,7 @@ message CenterNet {
// within error bars. // within error bars.
optional bool use_only_last_stage = 24 [default = false]; optional bool use_only_last_stage = 24 [default = false];
} }
optional DeepMACMaskEstimation deepmac_mask_estimation = 14; optional DeepMACMaskEstimation deepmac_mask_estimation = 14;
...@@ -514,6 +515,10 @@ message CenterNet { ...@@ -514,6 +515,10 @@ message CenterNet {
// from CenterNet. Use this optional parameter to apply traditional non max // from CenterNet. Use this optional parameter to apply traditional non max
// suppression and score thresholding. // suppression and score thresholding.
optional PostProcessing post_processing = 24; optional PostProcessing post_processing = 24;
// If set, dictionary items returned by the predict() function
// are appended to the output of postprocess().
optional bool output_prediction_dict = 25 [default = false];
} }
enum LossNormalize { enum LossNormalize {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment