Commit 3b2bcdaf authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

bugfixed to test with multi-class nms on multihead

parent a236428b
...@@ -230,6 +230,7 @@ class Detector3DTemplate(nn.Module): ...@@ -230,6 +230,7 @@ class Detector3DTemplate(nn.Module):
pred_scores.append(cur_pred_scores) pred_scores.append(cur_pred_scores)
pred_labels.append(cur_pred_labels) pred_labels.append(cur_pred_labels)
pred_boxes.append(cur_pred_boxes) pred_boxes.append(cur_pred_boxes)
cur_start_idx += cur_cls_preds.shape[0]
final_scores = torch.cat(pred_scores, dim=0) final_scores = torch.cat(pred_scores, dim=0)
final_labels = torch.cat(pred_labels, dim=0) final_labels = torch.cat(pred_labels, dim=0)
......
...@@ -36,29 +36,26 @@ def multi_classes_nms(cls_scores, box_preds, nms_config, score_thresh=None): ...@@ -36,29 +36,26 @@ def multi_classes_nms(cls_scores, box_preds, nms_config, score_thresh=None):
""" """
pred_scores, pred_labels, pred_boxes = [], [], [] pred_scores, pred_labels, pred_boxes = [], [], []
for k in range(cls_scores.shape[0]): for k in range(cls_scores.shape[1]):
if score_thresh is not None: if score_thresh is not None:
scores_mask = (cls_scores[:, k] >= score_thresh) scores_mask = (cls_scores[:, k] >= score_thresh)
box_scores = cls_scores[scores_mask, k] box_scores = cls_scores[scores_mask, k]
box_preds = box_preds[scores_mask] cur_box_preds = box_preds[scores_mask]
else: else:
box_scores = cls_scores[:, k] box_scores = cls_scores[:, k]
selected = [] selected = []
if box_scores.shape[0] > 0: if box_scores.shape[0] > 0:
box_scores_nms, indices = torch.topk(box_scores, k=min(nms_config.NMS_PRE_MAXSIZE, box_scores.shape[0])) box_scores_nms, indices = torch.topk(box_scores, k=min(nms_config.NMS_PRE_MAXSIZE, box_scores.shape[0]))
boxes_for_nms = box_preds[indices] boxes_for_nms = cur_box_preds[indices]
keep_idx, selected_scores = getattr(iou3d_nms_utils, nms_config.NMS_TYPE)( keep_idx, selected_scores = getattr(iou3d_nms_utils, nms_config.NMS_TYPE)(
boxes_for_nms[:, 0:7], box_scores_nms, nms_config.NMS_THRESH, **nms_config boxes_for_nms[:, 0:7], box_scores_nms, nms_config.NMS_THRESH, **nms_config
) )
selected = indices[keep_idx[:nms_config.NMS_POST_MAXSIZE]] selected = indices[keep_idx[:nms_config.NMS_POST_MAXSIZE]]
if score_thresh is not None:
selected = scores_mask.nonzero().view(-1)
pred_scores.append(box_scores[selected]) pred_scores.append(box_scores[selected])
pred_labels.append(box_scores.new_ones(selected.shape[0]) * k) pred_labels.append(box_scores.new_ones(len(selected)).long() * k)
pred_boxes.append(box_preds[selected]) pred_boxes.append(cur_box_preds[selected])
pred_scores = torch.cat(pred_scores, dim=0) pred_scores = torch.cat(pred_scores, dim=0)
pred_labels = torch.cat(pred_labels, dim=0) pred_labels = torch.cat(pred_labels, dim=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment