"docs/zh_cn/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "8708851eca17ad5d61307d7a08b702ad3e77bb4e"
Commit f57fa41e authored by Liangzhe Yuan's avatar Liangzhe Yuan Committed by TF Object Detection Team
Browse files

Fix a bug of slicing boxes in _postprocess_keypoints_for_class_and_image and...

Fix a bug of slicing boxes in _postprocess_keypoints_for_class_and_image and switch to use python assert in _get_shape.

PiperOrigin-RevId: 338693370
parent 414b7b74
...@@ -184,7 +184,7 @@ def _to_float32(x): ...@@ -184,7 +184,7 @@ def _to_float32(x):
def _get_shape(tensor, num_dims): def _get_shape(tensor, num_dims):
tf.Assert(tensor.get_shape().ndims == num_dims, [tensor]) assert len(tensor.shape.as_list()) == num_dims
return shape_utils.combined_static_and_dynamic_shape(tensor) return shape_utils.combined_static_and_dynamic_shape(tensor)
...@@ -3303,13 +3303,14 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3303,13 +3303,14 @@ class CenterNetMetaArch(model.DetectionModel):
keypoint_regression = keypoint_regression[batch_index:batch_index+1, ...] keypoint_regression = keypoint_regression[batch_index:batch_index+1, ...]
y_indices = y_indices[batch_index:batch_index+1, ...] y_indices = y_indices[batch_index:batch_index+1, ...]
x_indices = x_indices[batch_index:batch_index+1, ...] x_indices = x_indices[batch_index:batch_index+1, ...]
boxes_slice = boxes[batch_index:batch_index+1, ...]
# Gather the feature map locations corresponding to the object class. # Gather the feature map locations corresponding to the object class.
y_indices_for_kpt_class = tf.gather(y_indices, indices_with_kpt_class, y_indices_for_kpt_class = tf.gather(y_indices, indices_with_kpt_class,
axis=1) axis=1)
x_indices_for_kpt_class = tf.gather(x_indices, indices_with_kpt_class, x_indices_for_kpt_class = tf.gather(x_indices, indices_with_kpt_class,
axis=1) axis=1)
boxes_for_kpt_class = tf.gather(boxes, indices_with_kpt_class, axis=1) boxes_for_kpt_class = tf.gather(boxes_slice, indices_with_kpt_class, axis=1)
# Gather the regressed keypoints. Final tensor has shape # Gather the regressed keypoints. Final tensor has shape
# [1, num_instances, num_keypoints, 2]. # [1, num_instances, num_keypoints, 2].
...@@ -3334,8 +3335,11 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3334,8 +3335,11 @@ class CenterNetMetaArch(model.DetectionModel):
# [1, num_instances, num_keypoints, 2] and # [1, num_instances, num_keypoints, 2] and
# [1, num_instances, num_keypoints], respectively. # [1, num_instances, num_keypoints], respectively.
refined_keypoints, refined_scores = refine_keypoints( refined_keypoints, refined_scores = refine_keypoints(
regressed_keypoints_for_objects, keypoint_candidates, keypoint_scores, regressed_keypoints=regressed_keypoints_for_objects,
num_keypoint_candidates, bboxes=boxes_for_kpt_class, keypoint_candidates=keypoint_candidates,
keypoint_scores=keypoint_scores,
num_keypoint_candidates=num_keypoint_candidates,
bboxes=boxes_for_kpt_class,
unmatched_keypoint_score=kp_params.unmatched_keypoint_score, unmatched_keypoint_score=kp_params.unmatched_keypoint_score,
box_scale=kp_params.box_scale, box_scale=kp_params.box_scale,
candidate_search_scale=kp_params.candidate_search_scale, candidate_search_scale=kp_params.candidate_search_scale,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment