Unverified Commit f0e68404 authored by WRH's avatar WRH Committed by GitHub
Browse files

Fix some bug as well as unit test for nms_rotate (#728)

* add const to nms_rorate in pybind

* fix test nms rotated

* skip test instead of passing it

* fix lint

* update pytest skip syntax
parent 8eae7779
...@@ -178,8 +178,8 @@ Tensor top_pool_backward(Tensor input, Tensor grad_output); ...@@ -178,8 +178,8 @@ Tensor top_pool_backward(Tensor input, Tensor grad_output);
void box_iou_rotated(const Tensor boxes1, const Tensor boxes2, Tensor ious, void box_iou_rotated(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const bool aligned); const bool aligned);
Tensor nms_rotated(const Tensor dets, Tensor scores, Tensor order, Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
Tensor dets_sorted, const float iou_threshold, const Tensor dets_sorted, const float iou_threshold,
const int multi_label); const int multi_label);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
......
import numpy as np import numpy as np
import pytest
import torch import torch
class TestNmsRotated(object): @pytest.mark.skipif(
not torch.cuda.is_available(),
reason='GPU is required to test NMSRotated op')
class TestNmsRotated:
def test_ml_nms_rotated(self): def test_ml_nms_rotated(self):
if not torch.cuda.is_available():
return
from mmcv.ops import nms_rotated from mmcv.ops import nms_rotated
np_boxes = np.array( np_boxes = np.array(
[[6.0, 3.0, 8.0, 7.0, 0.5, 0.7], [3.0, 6.0, 9.0, 11.0, 0.6, 0.8], [[6.0, 3.0, 8.0, 7.0, 0.5, 0.7], [3.0, 6.0, 9.0, 11.0, 0.6, 0.8],
...@@ -24,14 +26,12 @@ class TestNmsRotated(object): ...@@ -24,14 +26,12 @@ class TestNmsRotated(object):
boxes = torch.from_numpy(np_boxes).cuda() boxes = torch.from_numpy(np_boxes).cuda()
labels = torch.from_numpy(np_labels).cuda() labels = torch.from_numpy(np_labels).cuda()
dets, keep_inds = nms_rotated(boxes, 0.5, labels, True) dets, keep_inds = nms_rotated(boxes[:, :5], boxes[:, -1], 0.5, labels)
assert np.allclose(dets.cpu().numpy(), np_expect_dets) assert np.allclose(dets.cpu().numpy()[:, :5], np_expect_dets)
assert np.allclose(keep_inds.cpu().numpy(), np_expect_keep_inds) assert np.allclose(keep_inds.cpu().numpy(), np_expect_keep_inds)
def test_nms_rotated(self): def test_nms_rotated(self):
if not torch.cuda.is_available():
return
from mmcv.ops import nms_rotated from mmcv.ops import nms_rotated
np_boxes = np.array( np_boxes = np.array(
[[6.0, 3.0, 8.0, 7.0, 0.5, 0.7], [3.0, 6.0, 9.0, 11.0, 0.6, 0.8], [[6.0, 3.0, 8.0, 7.0, 0.5, 0.7], [3.0, 6.0, 9.0, 11.0, 0.6, 0.8],
...@@ -47,6 +47,6 @@ class TestNmsRotated(object): ...@@ -47,6 +47,6 @@ class TestNmsRotated(object):
boxes = torch.from_numpy(np_boxes).cuda() boxes = torch.from_numpy(np_boxes).cuda()
dets, keep_inds = nms_rotated(boxes, 0.5) dets, keep_inds = nms_rotated(boxes[:, :5], boxes[:, -1], 0.5)
assert np.allclose(dets.cpu().numpy(), np_expect_dets) assert np.allclose(dets.cpu().numpy()[:, :5], np_expect_dets)
assert np.allclose(keep_inds.cpu().numpy(), np_expect_keep_inds) assert np.allclose(keep_inds.cpu().numpy(), np_expect_keep_inds)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment