Unverified Commit eab8cfa2 authored by Danila Rukhovich's avatar Danila Rukhovich Committed by GitHub
Browse files

[Fix] NMS for Point RCNN (#1418)

* fix nms for point rcnn

* add else case
parent 023c88b5
......@@ -3,6 +3,7 @@ import torch
from mmcv.runner import BaseModule, force_fp32
from torch import nn as nn
from mmdet3d.core import xywhr2xyxyr
from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes,
LiDARInstance3DBoxes)
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
......@@ -320,29 +321,33 @@ class PointRPNHead(BaseModule):
else:
raise NotImplementedError('Unsupported bbox type!')
bbox = bbox.tensor[nonempty_box_mask]
bbox = bbox[nonempty_box_mask]
if self.test_cfg.score_thr is not None:
score_thr = self.test_cfg.score_thr
keep = (obj_scores >= score_thr)
obj_scores = obj_scores[keep]
sem_scores = sem_scores[keep]
bbox = bbox[keep]
bbox = bbox.tensor[keep]
if obj_scores.shape[0] > 0:
topk = min(nms_cfg.nms_pre, obj_scores.shape[0])
obj_scores_nms, indices = torch.topk(obj_scores, k=topk)
bbox_for_nms = bbox[indices]
bbox_for_nms = xywhr2xyxyr(bbox[indices].bev)
sem_scores_nms = sem_scores[indices]
keep = nms_func(bbox_for_nms[:, 0:7], obj_scores_nms,
nms_cfg.iou_thr)
keep = nms_func(bbox_for_nms, obj_scores_nms, nms_cfg.iou_thr)
keep = keep[:nms_cfg.nms_post]
bbox_selected = bbox_for_nms[keep]
bbox_selected = bbox.tensor[indices][keep]
score_selected = obj_scores_nms[keep]
cls_preds = sem_scores_nms[keep]
labels = torch.argmax(cls_preds, -1)
else:
bbox_selected = bbox.tensor
score_selected = obj_scores.new_zeros([0])
labels = obj_scores.new_zeros([0])
cls_preds = obj_scores.new_zeros([0, sem_scores.shape[-1]])
return bbox_selected, score_selected, labels, cls_preds
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment