Unverified Commit a34e1e7f authored by Danila Rukhovich's avatar Danila Rukhovich Committed by GitHub
Browse files

[Fix] Fix bugs in nms (#1479)

* fix bugs in nms

* fix test_merge_augs
parent 1219933f
...@@ -4,8 +4,6 @@ import numpy as np ...@@ -4,8 +4,6 @@ import numpy as np
import torch import torch
from mmcv.ops import nms, nms_rotated from mmcv.ops import nms, nms_rotated
from ..bbox import xywhr2xyxyr
def box3d_multiclass_nms(mlvl_bboxes, def box3d_multiclass_nms(mlvl_bboxes,
mlvl_bboxes_for_nms, mlvl_bboxes_for_nms,
...@@ -254,6 +252,8 @@ def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None): ...@@ -254,6 +252,8 @@ def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None):
if pre_max_size is not None: if pre_max_size is not None:
order = order[:pre_max_size] order = order[:pre_max_size]
boxes = boxes[order].contiguous() boxes = boxes[order].contiguous()
scores = scores[order]
# xyxyr -> back to xywhr # xyxyr -> back to xywhr
# note: better skip this step before nms_bev call in the future # note: better skip this step before nms_bev call in the future
boxes = torch.stack( boxes = torch.stack(
...@@ -262,6 +262,7 @@ def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None): ...@@ -262,6 +262,7 @@ def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None):
dim=-1) dim=-1)
keep = nms_rotated(boxes, scores, thresh)[1] keep = nms_rotated(boxes, scores, thresh)[1]
keep = order[keep]
if post_max_size is not None: if post_max_size is not None:
keep = keep[:post_max_size] keep = keep[:post_max_size]
return keep return keep
...@@ -284,4 +285,4 @@ def nms_normal_bev(boxes, scores, thresh): ...@@ -284,4 +285,4 @@ def nms_normal_bev(boxes, scores, thresh):
torch.Tensor: Remaining indices with scores in descending order. torch.Tensor: Remaining indices with scores in descending order.
""" """
assert boxes.shape[1] == 5, 'Input boxes shape should be [N, 5]' assert boxes.shape[1] == 5, 'Input boxes shape should be [N, 5]'
return nms(xywhr2xyxyr(boxes)[:, :-1], scores, thresh)[1] return nms(boxes[:, :-1], scores, thresh)[1]
...@@ -29,11 +29,17 @@ def test_merge_aug_bboxes_3d(): ...@@ -29,11 +29,17 @@ def test_merge_aug_bboxes_3d():
[2.5831, 4.8117, -1.2733, 0.5852, 0.8832, 0.9733, 1.6500], [2.5831, 4.8117, -1.2733, 0.5852, 0.8832, 0.9733, 1.6500],
[-1.0864, 1.9045, -1.2000, 0.7128, 1.5631, 2.1045, 0.1022]], [-1.0864, 1.9045, -1.2000, 0.7128, 1.5631, 2.1045, 0.1022]],
device='cuda')) device='cuda'))
labels_3d = torch.tensor([0, 7, 6]) labels_3d = torch.tensor([0, 7, 6], device='cuda')
scores_3d = torch.tensor([0.5, 1.0, 1.0]) scores_3d_1 = torch.tensor([0.3, 0.6, 0.9], device='cuda')
aug_result = dict( scores_3d_2 = torch.tensor([0.2, 0.5, 0.8], device='cuda')
boxes_3d=boxes_3d, labels_3d=labels_3d, scores_3d=scores_3d) scores_3d_3 = torch.tensor([0.1, 0.4, 0.7], device='cuda')
aug_results = [aug_result, aug_result, aug_result] aug_result_1 = dict(
boxes_3d=boxes_3d, labels_3d=labels_3d, scores_3d=scores_3d_1)
aug_result_2 = dict(
boxes_3d=boxes_3d, labels_3d=labels_3d, scores_3d=scores_3d_2)
aug_result_3 = dict(
boxes_3d=boxes_3d, labels_3d=labels_3d, scores_3d=scores_3d_3)
aug_results = [aug_result_1, aug_result_2, aug_result_3]
test_cfg = mmcv.ConfigDict( test_cfg = mmcv.ConfigDict(
use_rotate_nms=True, use_rotate_nms=True,
nms_across_levels=False, nms_across_levels=False,
...@@ -53,9 +59,8 @@ def test_merge_aug_bboxes_3d(): ...@@ -53,9 +59,8 @@ def test_merge_aug_bboxes_3d():
[1.0473, -4.1687, -1.2317, 2.3021, 1.8876, 1.9696, -1.6956], [1.0473, -4.1687, -1.2317, 2.3021, 1.8876, 1.9696, -1.6956],
[-1.0473, 4.1687, -1.2317, 2.3021, 1.8876, 1.9696, 1.4460], [-1.0473, 4.1687, -1.2317, 2.3021, 1.8876, 1.9696, 1.4460],
[2.0946, 8.3374, -2.4634, 4.6042, 3.7752, 3.9392, 1.6956]]) [2.0946, 8.3374, -2.4634, 4.6042, 3.7752, 3.9392, 1.6956]])
expected_scores_3d = torch.tensor([ expected_scores_3d = torch.tensor(
1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.5000, 0.5000, 0.5000 [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])
])
expected_labels_3d = torch.tensor([6, 6, 6, 7, 7, 7, 0, 0, 0]) expected_labels_3d = torch.tensor([6, 6, 6, 7, 7, 7, 0, 0, 0])
assert torch.allclose(results['boxes_3d'].tensor, expected_boxes_3d) assert torch.allclose(results['boxes_3d'].tensor, expected_boxes_3d)
assert torch.allclose(results['scores_3d'], expected_scores_3d) assert torch.allclose(results['scores_3d'], expected_scores_3d)
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import numpy as np import numpy as np
import pytest
import torch import torch
...@@ -73,3 +74,41 @@ def test_circle_nms(): ...@@ -73,3 +74,41 @@ def test_circle_nms():
keep = circle_nms(boxes.numpy(), 0.175) keep = circle_nms(boxes.numpy(), 0.175)
expected_keep = [1, 2, 3, 4, 5, 6, 7, 8, 9] expected_keep = [1, 2, 3, 4, 5, 6, 7, 8, 9]
assert np.all(keep == expected_keep) assert np.all(keep == expected_keep)
# copied from tests/test_ops/test_iou3d.py from mmcv<=1.5
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_nms_bev():
from mmdet3d.core.post_processing import nms_bev
np_boxes = np.array(
[[6.0, 3.0, 8.0, 7.0, 2.0], [3.0, 6.0, 9.0, 11.0, 1.0],
[3.0, 7.0, 10.0, 12.0, 1.0], [1.0, 4.0, 13.0, 7.0, 3.0]],
dtype=np.float32)
np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32)
np_inds = np.array([1, 0, 3])
boxes = torch.from_numpy(np_boxes)
scores = torch.from_numpy(np_scores)
inds = nms_bev(boxes.cuda(), scores.cuda(), thresh=0.3)
assert np.allclose(inds.cpu().numpy(), np_inds)
# copied from tests/test_ops/test_iou3d.py from mmcv<=1.5
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_nms_normal_bev():
from mmdet3d.core.post_processing import nms_normal_bev
np_boxes = np.array(
[[6.0, 3.0, 8.0, 7.0, 2.0], [3.0, 6.0, 9.0, 11.0, 1.0],
[3.0, 7.0, 10.0, 12.0, 1.0], [1.0, 4.0, 13.0, 7.0, 3.0]],
dtype=np.float32)
np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32)
np_inds = np.array([1, 0, 3])
boxes = torch.from_numpy(np_boxes)
scores = torch.from_numpy(np_scores)
inds = nms_normal_bev(boxes.cuda(), scores.cuda(), thresh=0.3)
assert np.allclose(inds.cpu().numpy(), np_inds)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment