Commit 3b2bcdaf authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

bugfixed to test with multi-class nms on multihead

parent a236428b
......@@ -230,6 +230,7 @@ class Detector3DTemplate(nn.Module):
pred_scores.append(cur_pred_scores)
pred_labels.append(cur_pred_labels)
pred_boxes.append(cur_pred_boxes)
cur_start_idx += cur_cls_preds.shape[0]
final_scores = torch.cat(pred_scores, dim=0)
final_labels = torch.cat(pred_labels, dim=0)
......
......@@ -36,29 +36,26 @@ def multi_classes_nms(cls_scores, box_preds, nms_config, score_thresh=None):
"""
pred_scores, pred_labels, pred_boxes = [], [], []
for k in range(cls_scores.shape[0]):
for k in range(cls_scores.shape[1]):
if score_thresh is not None:
scores_mask = (cls_scores[:, k] >= score_thresh)
box_scores = cls_scores[scores_mask, k]
box_preds = box_preds[scores_mask]
cur_box_preds = box_preds[scores_mask]
else:
box_scores = cls_scores[:, k]
selected = []
if box_scores.shape[0] > 0:
box_scores_nms, indices = torch.topk(box_scores, k=min(nms_config.NMS_PRE_MAXSIZE, box_scores.shape[0]))
boxes_for_nms = box_preds[indices]
boxes_for_nms = cur_box_preds[indices]
keep_idx, selected_scores = getattr(iou3d_nms_utils, nms_config.NMS_TYPE)(
boxes_for_nms[:, 0:7], box_scores_nms, nms_config.NMS_THRESH, **nms_config
)
selected = indices[keep_idx[:nms_config.NMS_POST_MAXSIZE]]
if score_thresh is not None:
selected = scores_mask.nonzero().view(-1)
pred_scores.append(box_scores[selected])
pred_labels.append(box_scores.new_ones(selected.shape[0]) * k)
pred_boxes.append(box_preds[selected])
pred_labels.append(box_scores.new_ones(len(selected)).long() * k)
pred_boxes.append(cur_box_preds[selected])
pred_scores = torch.cat(pred_scores, dim=0)
pred_labels = torch.cat(pred_labels, dim=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment