Unverified Commit eab8cfa2 authored by Danila Rukhovich's avatar Danila Rukhovich Committed by GitHub
Browse files

[Fix] NMS for Point RCNN (#1418)

* fix nms for point rcnn

* add else case
parent 023c88b5
...@@ -3,6 +3,7 @@ import torch ...@@ -3,6 +3,7 @@ import torch
from mmcv.runner import BaseModule, force_fp32 from mmcv.runner import BaseModule, force_fp32
from torch import nn as nn from torch import nn as nn
from mmdet3d.core import xywhr2xyxyr
from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes, from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes,
LiDARInstance3DBoxes) LiDARInstance3DBoxes)
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
...@@ -320,29 +321,33 @@ class PointRPNHead(BaseModule): ...@@ -320,29 +321,33 @@ class PointRPNHead(BaseModule):
else: else:
raise NotImplementedError('Unsupported bbox type!') raise NotImplementedError('Unsupported bbox type!')
bbox = bbox.tensor[nonempty_box_mask] bbox = bbox[nonempty_box_mask]
if self.test_cfg.score_thr is not None: if self.test_cfg.score_thr is not None:
score_thr = self.test_cfg.score_thr score_thr = self.test_cfg.score_thr
keep = (obj_scores >= score_thr) keep = (obj_scores >= score_thr)
obj_scores = obj_scores[keep] obj_scores = obj_scores[keep]
sem_scores = sem_scores[keep] sem_scores = sem_scores[keep]
bbox = bbox[keep] bbox = bbox.tensor[keep]
if obj_scores.shape[0] > 0: if obj_scores.shape[0] > 0:
topk = min(nms_cfg.nms_pre, obj_scores.shape[0]) topk = min(nms_cfg.nms_pre, obj_scores.shape[0])
obj_scores_nms, indices = torch.topk(obj_scores, k=topk) obj_scores_nms, indices = torch.topk(obj_scores, k=topk)
bbox_for_nms = bbox[indices] bbox_for_nms = xywhr2xyxyr(bbox[indices].bev)
sem_scores_nms = sem_scores[indices] sem_scores_nms = sem_scores[indices]
keep = nms_func(bbox_for_nms[:, 0:7], obj_scores_nms, keep = nms_func(bbox_for_nms, obj_scores_nms, nms_cfg.iou_thr)
nms_cfg.iou_thr)
keep = keep[:nms_cfg.nms_post] keep = keep[:nms_cfg.nms_post]
bbox_selected = bbox_for_nms[keep] bbox_selected = bbox.tensor[indices][keep]
score_selected = obj_scores_nms[keep] score_selected = obj_scores_nms[keep]
cls_preds = sem_scores_nms[keep] cls_preds = sem_scores_nms[keep]
labels = torch.argmax(cls_preds, -1) labels = torch.argmax(cls_preds, -1)
else:
bbox_selected = bbox.tensor
score_selected = obj_scores.new_zeros([0])
labels = obj_scores.new_zeros([0])
cls_preds = obj_scores.new_zeros([0, sem_scores.shape[-1]])
return bbox_selected, score_selected, labels, cls_preds return bbox_selected, score_selected, labels, cls_preds
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment