"include/vscode:/vscode.git/clone" did not exist on "67dc11971f9f7878e28496dc28382923598557ea"
Commit a79ca771 authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Fixed the issue of NMS not working well with CenterNet in the multiclass

scenario. See b/218767303#comment12 for more detailed explanation.

PiperOrigin-RevId: 457641652
parent 5670119e
...@@ -4235,6 +4235,15 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -4235,6 +4235,15 @@ class CenterNetMetaArch(model.DetectionModel):
axis=-2) axis=-2)
multiclass_scores = postprocess_dict[ multiclass_scores = postprocess_dict[
fields.DetectionResultFields.detection_multiclass_scores] fields.DetectionResultFields.detection_multiclass_scores]
num_classes = tf.shape(multiclass_scores)[2]
class_mask = tf.cast(
tf.one_hot(
postprocess_dict[fields.DetectionResultFields.detection_classes],
depth=num_classes), tf.bool)
# Surpress the scores of those unselected classes to be zeros. Otherwise,
# the downstream NMS ops might be confused and introduce issues.
multiclass_scores = tf.where(
class_mask, multiclass_scores, tf.zeros_like(multiclass_scores))
num_valid_boxes = postprocess_dict.pop( num_valid_boxes = postprocess_dict.pop(
fields.DetectionResultFields.num_detections) fields.DetectionResultFields.num_detections)
# Remove scores and classes as NMS will compute these form multiclass # Remove scores and classes as NMS will compute these form multiclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment