Unverified Commit b6167d59 authored by zhanggefan's avatar zhanggefan Committed by GitHub
Browse files

[Enhancement] Add clockwise argument to ops box_iou_rotated and nms_rotated (#1592)



* add clockwise arguments to ops box_iou_rotated and nms_rotated

* refactor docs

* change code that may incur stopped gradient.

* refactor docs

* Update mmcv/ops/nms.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/ops/box_iou_rotated.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent b586cc2f
......@@ -4,7 +4,11 @@ from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['box_iou_rotated'])
def box_iou_rotated(bboxes1, bboxes2, mode='iou', aligned=False):
def box_iou_rotated(bboxes1,
bboxes2,
mode='iou',
aligned=False,
clockwise=True):
"""Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in
......@@ -14,6 +18,94 @@ def box_iou_rotated(bboxes1, bboxes2, mode='iou', aligned=False):
of bboxes1 and bboxes2, otherwise the ious between each aligned pair of
bboxes1 and bboxes2.
.. note::
The operator assumes:
1) The positive direction along x axis is left -> right.
2) The positive direction along y axis is top -> down.
3) The w border is in parallel with x axis when angle = 0.
However, there are 2 opposite definitions of the positive angular
direction, clockwise (CW) and counter-clockwise (CCW). MMCV supports
both definitions and uses CW by default.
Please set ``clockwise=False`` if you are using the CCW definition.
The coordinate system when ``clockwise`` is ``True`` (default)
.. code-block:: none
0-------------------> x (0 rad)
| A-------------B
| | |
| | box h
| | angle=0 |
| D------w------C
v
y (pi/2 rad)
In such coordination system the rotation matrix is
.. math::
\\begin{pmatrix}
\\cos\\alpha & -\\sin\\alpha \\\\
\\sin\\alpha & \\cos\\alpha
\\end{pmatrix}
The coordinates of the corner point A can be calculated as:
.. math::
P_A=
\\begin{pmatrix} x_A \\\\ y_A\\end{pmatrix}
=
\\begin{pmatrix} x_{center} \\\\ y_{center}\\end{pmatrix} +
\\begin{pmatrix}\\cos\\alpha & -\\sin\\alpha \\\\
\\sin\\alpha & \\cos\\alpha\\end{pmatrix}
\\begin{pmatrix} -0.5w \\\\ -0.5h\\end{pmatrix} \\\\
=
\\begin{pmatrix} x_{center}-0.5w\\cos\\alpha+0.5h\\sin\\alpha
\\\\
y_{center}-0.5w\\sin\\alpha-0.5h\\cos\\alpha\\end{pmatrix}
The coordinate system when ``clockwise`` is ``False``
.. code-block:: none
0-------------------> x (0 rad)
| A-------------B
| | |
| | box h
| | angle=0 |
| D------w------C
v
y (-pi/2 rad)
In such coordination system the rotation matrix is
.. math::
\\begin{pmatrix}
\\cos\\alpha & \\sin\\alpha \\\\
-\\sin\\alpha & \\cos\\alpha
\\end{pmatrix}
The coordinates of the corner point A can be calculated as:
.. math::
P_A=
\\begin{pmatrix} x_A \\\\ y_A\\end{pmatrix}
=
\\begin{pmatrix} x_{center} \\\\ y_{center}\\end{pmatrix} +
\\begin{pmatrix}\\cos\\alpha & \\sin\\alpha \\\\
-\\sin\\alpha & \\cos\\alpha\\end{pmatrix}
\\begin{pmatrix} -0.5w \\\\ -0.5h\\end{pmatrix} \\\\
=
\\begin{pmatrix} x_{center}-0.5w\\cos\\alpha-0.5h\\sin\\alpha
\\\\
y_{center}+0.5w\\sin\\alpha-0.5h\\cos\\alpha\\end{pmatrix}
Args:
boxes1 (torch.Tensor): rotated bboxes 1. It has shape (N, 5),
indicating (x, y, w, h, theta) for each row. Note that theta is in
......@@ -23,6 +115,9 @@ def box_iou_rotated(bboxes1, bboxes2, mode='iou', aligned=False):
radian.
mode (str): "iou" (intersection over union) or iof (intersection over
foreground).
clockwise (bool): flag indicating whether the positive angular
orientation is clockwise. default True.
`New in version 1.4.3.`
Returns:
torch.Tensor: Return the ious betweens boxes. If ``aligned`` is
......@@ -37,6 +132,11 @@ def box_iou_rotated(bboxes1, bboxes2, mode='iou', aligned=False):
ious = bboxes1.new_zeros(rows)
else:
ious = bboxes1.new_zeros((rows * cols))
if not clockwise:
flip_mat = bboxes1.new_ones(bboxes1.shape[-1])
flip_mat[-1] = -1
bboxes1 = bboxes1 * flip_mat
bboxes2 = bboxes2 * flip_mat
bboxes1 = bboxes1.contiguous()
bboxes2 = bboxes2.contiguous()
ext_module.box_iou_rotated(
......
......@@ -392,7 +392,7 @@ def nms_match(dets, iou_threshold):
return [np.array(m, dtype=int) for m in matched]
def nms_rotated(dets, scores, iou_threshold, labels=None):
def nms_rotated(dets, scores, iou_threshold, labels=None, clockwise=True):
"""Performs non-maximum suppression (NMS) on the rotated boxes according to
their intersection-over-union (IoU).
......@@ -400,11 +400,14 @@ def nms_rotated(dets, scores, iou_threshold, labels=None):
IoU greater than iou_threshold with another (higher scoring) rotated box.
Args:
boxes (Tensor): Rotated boxes in shape (N, 5). They are expected to
dets (Tensor): Rotated boxes in shape (N, 5). They are expected to
be in (x_ctr, y_ctr, width, height, angle_radian) format.
scores (Tensor): scores in shape (N, ).
iou_threshold (float): IoU thresh for NMS.
labels (Tensor): boxes' label in shape (N,).
clockwise (bool): flag indicating whether the positive angular
orientation is clockwise. default True.
`New in version 1.4.3.`
Returns:
tuple: kept dets(boxes and scores) and indice, which is always the
......@@ -412,11 +415,17 @@ def nms_rotated(dets, scores, iou_threshold, labels=None):
"""
if dets.shape[0] == 0:
return dets, None
if not clockwise:
flip_mat = dets.new_ones(dets.shape[-1])
flip_mat[-1] = -1
dets_cw = dets * flip_mat
else:
dets_cw = dets
multi_label = labels is not None
if multi_label:
dets_wl = torch.cat((dets, labels.unsqueeze(1)), 1)
dets_wl = torch.cat((dets_cw, labels.unsqueeze(1)), 1)
else:
dets_wl = dets
dets_wl = dets_cw
_, order = scores.sort(0, descending=True)
dets_sorted = dets_wl.index_select(0, order)
......
......@@ -25,6 +25,7 @@ class TestBoxIoURotated(object):
boxes1 = torch.from_numpy(np_boxes1)
boxes2 = torch.from_numpy(np_boxes2)
# test cw angle definition
ious = box_iou_rotated(boxes1, boxes2)
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
......@@ -32,6 +33,16 @@ class TestBoxIoURotated(object):
assert np.allclose(
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
# test ccw angle definition
boxes1[..., -1] *= -1
boxes2[..., -1] *= -1
ious = box_iou_rotated(boxes1, boxes2, clockwise=False)
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
ious = box_iou_rotated(boxes1, boxes2, aligned=True, clockwise=False)
assert np.allclose(
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_box_iou_rotated_cuda(self):
......@@ -54,6 +65,7 @@ class TestBoxIoURotated(object):
boxes1 = torch.from_numpy(np_boxes1).cuda()
boxes2 = torch.from_numpy(np_boxes2).cuda()
# test cw angle definition
ious = box_iou_rotated(boxes1, boxes2)
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
......@@ -61,6 +73,16 @@ class TestBoxIoURotated(object):
assert np.allclose(
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
# test ccw angle definition
boxes1[..., -1] *= -1
boxes2[..., -1] *= -1
ious = box_iou_rotated(boxes1, boxes2, clockwise=False)
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
ious = box_iou_rotated(boxes1, boxes2, aligned=True, clockwise=False)
assert np.allclose(
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
def test_box_iou_rotated_iof_cpu(self):
from mmcv.ops import box_iou_rotated
np_boxes1 = np.asarray(
......@@ -81,12 +103,23 @@ class TestBoxIoURotated(object):
boxes1 = torch.from_numpy(np_boxes1)
boxes2 = torch.from_numpy(np_boxes2)
# test cw angle definition
ious = box_iou_rotated(boxes1, boxes2, mode='iof')
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
ious = box_iou_rotated(boxes1, boxes2, mode='iof', aligned=True)
assert np.allclose(
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
# test ccw angle definition
boxes1[..., -1] *= -1
boxes2[..., -1] *= -1
ious = box_iou_rotated(boxes1, boxes2, mode='iof', clockwise=False)
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
ious = box_iou_rotated(
boxes1, boxes2, mode='iof', aligned=True, clockwise=False)
assert np.allclose(
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_box_iou_rotated_iof_cuda(self):
......@@ -109,9 +142,21 @@ class TestBoxIoURotated(object):
boxes1 = torch.from_numpy(np_boxes1).cuda()
boxes2 = torch.from_numpy(np_boxes2).cuda()
# test cw angle definition
ious = box_iou_rotated(boxes1, boxes2, mode='iof')
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
ious = box_iou_rotated(boxes1, boxes2, mode='iof', aligned=True)
assert np.allclose(
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
# test ccw angle definition
boxes1[..., -1] *= -1
boxes2[..., -1] *= -1
ious = box_iou_rotated(boxes1, boxes2, mode='iof', clockwise=False)
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
ious = box_iou_rotated(
boxes1, boxes2, mode='iof', aligned=True, clockwise=False)
assert np.allclose(
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
......@@ -26,11 +26,20 @@ class TestNmsRotated:
boxes = torch.from_numpy(np_boxes).cuda()
labels = torch.from_numpy(np_labels).cuda()
# test cw angle definition
dets, keep_inds = nms_rotated(boxes[:, :5], boxes[:, -1], 0.5, labels)
assert np.allclose(dets.cpu().numpy()[:, :5], np_expect_dets)
assert np.allclose(keep_inds.cpu().numpy(), np_expect_keep_inds)
# test ccw angle definition
boxes[..., -2] *= -1
dets, keep_inds = nms_rotated(
boxes[:, :5], boxes[:, -1], 0.5, labels, clockwise=False)
dets[..., -2] *= -1
assert np.allclose(dets.cpu().numpy()[:, :5], np_expect_dets)
assert np.allclose(keep_inds.cpu().numpy(), np_expect_keep_inds)
def test_nms_rotated(self):
from mmcv.ops import nms_rotated
np_boxes = np.array(
......@@ -47,6 +56,15 @@ class TestNmsRotated:
boxes = torch.from_numpy(np_boxes).cuda()
# test cw angle definition
dets, keep_inds = nms_rotated(boxes[:, :5], boxes[:, -1], 0.5)
assert np.allclose(dets.cpu().numpy()[:, :5], np_expect_dets)
assert np.allclose(keep_inds.cpu().numpy(), np_expect_keep_inds)
# test ccw angle definition
boxes[..., -2] *= -1
dets, keep_inds = nms_rotated(
boxes[:, :5], boxes[:, -1], 0.5, clockwise=False)
dets[..., -2] *= -1
assert np.allclose(dets.cpu().numpy()[:, :5], np_expect_dets)
assert np.allclose(keep_inds.cpu().numpy(), np_expect_keep_inds)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment