Commit 0c85c06c authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Extended CenterNet model to predict keypoint depth information.

PiperOrigin-RevId: 359344675
parent 3cfd0ba0
...@@ -868,7 +868,10 @@ def keypoint_proto_to_params(kp_config, keypoint_map_dict): ...@@ -868,7 +868,10 @@ def keypoint_proto_to_params(kp_config, keypoint_map_dict):
candidate_search_scale=kp_config.candidate_search_scale, candidate_search_scale=kp_config.candidate_search_scale,
candidate_ranking_mode=kp_config.candidate_ranking_mode, candidate_ranking_mode=kp_config.candidate_ranking_mode,
offset_peak_radius=kp_config.offset_peak_radius, offset_peak_radius=kp_config.offset_peak_radius,
per_keypoint_offset=kp_config.per_keypoint_offset) per_keypoint_offset=kp_config.per_keypoint_offset,
predict_depth=kp_config.predict_depth,
per_keypoint_depth=kp_config.per_keypoint_depth,
keypoint_depth_loss_weight=kp_config.keypoint_depth_loss_weight)
def object_detection_proto_to_params(od_config): def object_detection_proto_to_params(od_config):
......
...@@ -116,6 +116,9 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest): ...@@ -116,6 +116,9 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
candidate_ranking_mode: "score_distance_ratio" candidate_ranking_mode: "score_distance_ratio"
offset_peak_radius: 3 offset_peak_radius: 3
per_keypoint_offset: true per_keypoint_offset: true
predict_depth: true
per_keypoint_depth: true
keypoint_depth_loss_weight: 0.3
""" """
config = text_format.Merge(task_proto_txt, config = text_format.Merge(task_proto_txt,
center_net_pb2.CenterNet.KeypointEstimation()) center_net_pb2.CenterNet.KeypointEstimation())
...@@ -264,6 +267,9 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest): ...@@ -264,6 +267,9 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
self.assertEqual(kp_params.candidate_ranking_mode, 'score_distance_ratio') self.assertEqual(kp_params.candidate_ranking_mode, 'score_distance_ratio')
self.assertEqual(kp_params.offset_peak_radius, 3) self.assertEqual(kp_params.offset_peak_radius, 3)
self.assertEqual(kp_params.per_keypoint_offset, True) self.assertEqual(kp_params.per_keypoint_offset, True)
self.assertEqual(kp_params.predict_depth, True)
self.assertEqual(kp_params.per_keypoint_depth, True)
self.assertAlmostEqual(kp_params.keypoint_depth_loss_weight, 0.3)
# Check mask related parameters. # Check mask related parameters.
self.assertAlmostEqual(model._mask_params.task_loss_weight, 0.7) self.assertAlmostEqual(model._mask_params.task_loss_weight, 0.7)
......
...@@ -165,6 +165,21 @@ message CenterNet { ...@@ -165,6 +165,21 @@ message CenterNet {
// out_height, out_width, 2 * num_keypoints] (recommended when the // out_height, out_width, 2 * num_keypoints] (recommended when the
// offset_peak_radius is not zero). // offset_peak_radius is not zero).
optional bool per_keypoint_offset = 18 [default = false]; optional bool per_keypoint_offset = 18 [default = false];
// Indicates whether to predict the depth of each keypoints. Note that this
// is only supported in the single class keypoint task.
optional bool predict_depth = 19 [default = false];
// Indicates whether to predict depths for each keypoint channel
// separately. If set False, the output depth target has the shape
// [batch_size, out_height, out_width, 1]. If set True, the output depth
// target has the shape [batch_size, out_height, out_width,
// num_keypoints]. Recommend to set this value and "per_keypoint_offset" to
// both be True at the same time.
optional bool per_keypoint_depth = 20 [default = false];
// The weight of the keypoint depth loss.
optional float keypoint_depth_loss_weight = 21 [default = 1.0];
} }
repeated KeypointEstimation keypoint_estimation_task = 7; repeated KeypointEstimation keypoint_estimation_task = 7;
...@@ -278,7 +293,6 @@ message CenterNet { ...@@ -278,7 +293,6 @@ message CenterNet {
// from CenterNet. Use this optional parameter to apply traditional non max // from CenterNet. Use this optional parameter to apply traditional non max
// suppression and score thresholding. // suppression and score thresholding.
optional PostProcessing post_processing = 24; optional PostProcessing post_processing = 24;
} }
message CenterNetFeatureExtractor { message CenterNetFeatureExtractor {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment