Commit 0c85c06c authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Extended CenterNet model to predict keypoint depth information.

PiperOrigin-RevId: 359344675
parent 3cfd0ba0
......@@ -868,7 +868,10 @@ def keypoint_proto_to_params(kp_config, keypoint_map_dict):
candidate_search_scale=kp_config.candidate_search_scale,
candidate_ranking_mode=kp_config.candidate_ranking_mode,
offset_peak_radius=kp_config.offset_peak_radius,
per_keypoint_offset=kp_config.per_keypoint_offset)
per_keypoint_offset=kp_config.per_keypoint_offset,
predict_depth=kp_config.predict_depth,
per_keypoint_depth=kp_config.per_keypoint_depth,
keypoint_depth_loss_weight=kp_config.keypoint_depth_loss_weight)
def object_detection_proto_to_params(od_config):
......
......@@ -116,6 +116,9 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
candidate_ranking_mode: "score_distance_ratio"
offset_peak_radius: 3
per_keypoint_offset: true
predict_depth: true
per_keypoint_depth: true
keypoint_depth_loss_weight: 0.3
"""
config = text_format.Merge(task_proto_txt,
center_net_pb2.CenterNet.KeypointEstimation())
......@@ -264,6 +267,9 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
self.assertEqual(kp_params.candidate_ranking_mode, 'score_distance_ratio')
self.assertEqual(kp_params.offset_peak_radius, 3)
self.assertEqual(kp_params.per_keypoint_offset, True)
self.assertEqual(kp_params.predict_depth, True)
self.assertEqual(kp_params.per_keypoint_depth, True)
self.assertAlmostEqual(kp_params.keypoint_depth_loss_weight, 0.3)
# Check mask related parameters.
self.assertAlmostEqual(model._mask_params.task_loss_weight, 0.7)
......
......@@ -165,6 +165,21 @@ message CenterNet {
// out_height, out_width, 2 * num_keypoints] (recommended when the
// offset_peak_radius is not zero).
optional bool per_keypoint_offset = 18 [default = false];
// Indicates whether to predict the depth of each keypoints. Note that this
// is only supported in the single class keypoint task.
optional bool predict_depth = 19 [default = false];
// Indicates whether to predict depths for each keypoint channel
// separately. If set False, the output depth target has the shape
// [batch_size, out_height, out_width, 1]. If set True, the output depth
// target has the shape [batch_size, out_height, out_width,
// num_keypoints]. Recommend to set this value and "per_keypoint_offset" to
// both be True at the same time.
optional bool per_keypoint_depth = 20 [default = false];
// The weight of the keypoint depth loss.
optional float keypoint_depth_loss_weight = 21 [default = 1.0];
}
repeated KeypointEstimation keypoint_estimation_task = 7;
......@@ -278,7 +293,6 @@ message CenterNet {
// from CenterNet. Use this optional parameter to apply traditional non max
// suppression and score thresholding.
optional PostProcessing post_processing = 24;
}
message CenterNetFeatureExtractor {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment