Unverified Commit 9807c2d2 authored by Yanyi Liu's avatar Yanyi Liu Committed by GitHub
Browse files

[Fix] Fix batched_nms for rotated box and add type hints for nms.py (#2006)

* Fix batched_nms for rotated box.
Add type hint for nms.py

* Add test

* doc string

* revert symbolic hint

* fix max_coordinate

* add comment

* rename type

* fix typo docstring
parent 5a2906cb
import os import os
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from torch import Tensor
from mmcv.utils import deprecated_api_warning from mmcv.utils import deprecated_api_warning
from ..utils import ext_loader from ..utils import ext_loader
...@@ -14,8 +16,8 @@ ext_module = ext_loader.load_ext( ...@@ -14,8 +16,8 @@ ext_module = ext_loader.load_ext(
class NMSop(torch.autograd.Function): class NMSop(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, bboxes, scores, iou_threshold, offset, score_threshold, def forward(ctx: Any, bboxes: Tensor, scores: Tensor, iou_threshold: float,
max_num): offset: int, score_threshold: float, max_num: int) -> Tensor:
is_filtering_by_score = score_threshold > 0 is_filtering_by_score = score_threshold > 0
if is_filtering_by_score: if is_filtering_by_score:
valid_mask = scores > score_threshold valid_mask = scores > score_threshold
...@@ -83,8 +85,9 @@ class NMSop(torch.autograd.Function): ...@@ -83,8 +85,9 @@ class NMSop(torch.autograd.Function):
class SoftNMSop(torch.autograd.Function): class SoftNMSop(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, boxes, scores, iou_threshold, sigma, min_score, method, def forward(ctx: Any, boxes: Tensor, scores: Tensor, iou_threshold: float,
offset): sigma: float, min_score: float, method: int,
offset: int) -> Tuple[Tensor, Tensor]:
dets = boxes.new_empty((boxes.size(0), 5), device='cpu') dets = boxes.new_empty((boxes.size(0), 5), device='cpu')
inds = ext_module.softnms( inds = ext_module.softnms(
boxes.cpu(), boxes.cpu(),
...@@ -115,8 +118,16 @@ class SoftNMSop(torch.autograd.Function): ...@@ -115,8 +118,16 @@ class SoftNMSop(torch.autograd.Function):
return nms_out return nms_out
array_like_type = Union[Tensor, np.ndarray]
@deprecated_api_warning({'iou_thr': 'iou_threshold'}) @deprecated_api_warning({'iou_thr': 'iou_threshold'})
def nms(boxes, scores, iou_threshold, offset=0, score_threshold=0, max_num=-1): def nms(boxes: array_like_type,
scores: array_like_type,
iou_threshold: float,
offset: int = 0,
score_threshold: float = 0,
max_num: int = -1) -> Tuple[array_like_type, array_like_type]:
"""Dispatch to either CPU or GPU NMS implementations. """Dispatch to either CPU or GPU NMS implementations.
The input can be either torch tensor or numpy array. GPU NMS will be used The input can be either torch tensor or numpy array. GPU NMS will be used
...@@ -149,8 +160,8 @@ def nms(boxes, scores, iou_threshold, offset=0, score_threshold=0, max_num=-1): ...@@ -149,8 +160,8 @@ def nms(boxes, scores, iou_threshold, offset=0, score_threshold=0, max_num=-1):
>>> dets, inds = nms(boxes, scores, iou_threshold) >>> dets, inds = nms(boxes, scores, iou_threshold)
>>> assert len(inds) == len(dets) == 3 >>> assert len(inds) == len(dets) == 3
""" """
assert isinstance(boxes, (torch.Tensor, np.ndarray)) assert isinstance(boxes, (Tensor, np.ndarray))
assert isinstance(scores, (torch.Tensor, np.ndarray)) assert isinstance(scores, (Tensor, np.ndarray))
is_numpy = False is_numpy = False
if isinstance(boxes, np.ndarray): if isinstance(boxes, np.ndarray):
is_numpy = True is_numpy = True
...@@ -171,13 +182,13 @@ def nms(boxes, scores, iou_threshold, offset=0, score_threshold=0, max_num=-1): ...@@ -171,13 +182,13 @@ def nms(boxes, scores, iou_threshold, offset=0, score_threshold=0, max_num=-1):
@deprecated_api_warning({'iou_thr': 'iou_threshold'}) @deprecated_api_warning({'iou_thr': 'iou_threshold'})
def soft_nms(boxes, def soft_nms(boxes: array_like_type,
scores, scores: array_like_type,
iou_threshold=0.3, iou_threshold: float = 0.3,
sigma=0.5, sigma: float = 0.5,
min_score=1e-3, min_score: float = 1e-3,
method='linear', method: str = 'linear',
offset=0): offset: int = 0) -> Tuple[array_like_type, array_like_type]:
"""Dispatch to only CPU Soft NMS implementations. """Dispatch to only CPU Soft NMS implementations.
The input can be either a torch tensor or numpy array. The input can be either a torch tensor or numpy array.
...@@ -209,8 +220,8 @@ def soft_nms(boxes, ...@@ -209,8 +220,8 @@ def soft_nms(boxes,
>>> assert len(inds) == len(dets) == 5 >>> assert len(inds) == len(dets) == 5
""" """
assert isinstance(boxes, (torch.Tensor, np.ndarray)) assert isinstance(boxes, (Tensor, np.ndarray))
assert isinstance(scores, (torch.Tensor, np.ndarray)) assert isinstance(scores, (Tensor, np.ndarray))
is_numpy = False is_numpy = False
if isinstance(boxes, np.ndarray): if isinstance(boxes, np.ndarray):
is_numpy = True is_numpy = True
...@@ -250,7 +261,11 @@ def soft_nms(boxes, ...@@ -250,7 +261,11 @@ def soft_nms(boxes,
return dets.to(device=boxes.device), inds.to(device=boxes.device) return dets.to(device=boxes.device), inds.to(device=boxes.device)
def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False): def batched_nms(boxes: Tensor,
scores: Tensor,
idxs: Tensor,
nms_cfg: Optional[Dict],
class_agnostic: bool = False) -> Tuple[Tensor, Tensor]:
r"""Performs non-maximum suppression in a batched fashion. r"""Performs non-maximum suppression in a batched fashion.
Modified from `torchvision/ops/boxes.py#L39 Modified from `torchvision/ops/boxes.py#L39
...@@ -265,16 +280,16 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False): ...@@ -265,16 +280,16 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
returns sorted raw results when `nms_cfg` is None. returns sorted raw results when `nms_cfg` is None.
Args: Args:
boxes (torch.Tensor): boxes in shape (N, 4). boxes (torch.Tensor): boxes in shape (N, 4) or (N, 5).
scores (torch.Tensor): scores in shape (N, ). scores (torch.Tensor): scores in shape (N, ).
idxs (torch.Tensor): each index value correspond to a bbox cluster, idxs (torch.Tensor): each index value correspond to a bbox cluster,
and NMS will not be applied between elements of different idxs, and NMS will not be applied between elements of different idxs,
shape (N, ). shape (N, ).
nms_cfg (dict | None): Supports skipping the nms when `nms_cfg` nms_cfg (dict | optional): Supports skipping the nms when `nms_cfg`
is None, otherwise it should specify nms type and other is None, otherwise it should specify nms type and other
parameters like `iou_thr`. Possible keys includes the following. parameters like `iou_thr`. Possible keys includes the following.
- iou_thr (float): IoU threshold used for NMS. - iou_threshold (float): IoU threshold used for NMS.
- split_thr (float): threshold number of boxes. In some cases the - split_thr (float): threshold number of boxes. In some cases the
number of boxes is large (e.g., 200k). To avoid OOM during number of boxes is large (e.g., 200k). To avoid OOM during
training, the users could set `split_thr` to a small value. training, the users could set `split_thr` to a small value.
...@@ -283,7 +298,7 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False): ...@@ -283,7 +298,7 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
Defaults to 10000. Defaults to 10000.
class_agnostic (bool): if true, nms is class agnostic, class_agnostic (bool): if true, nms is class agnostic,
i.e. IoU thresholding happens over all boxes, i.e. IoU thresholding happens over all boxes,
regardless of the predicted class. regardless of the predicted class. Defaults to False.
Returns: Returns:
tuple: kept dets and indice. tuple: kept dets and indice.
...@@ -305,9 +320,26 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False): ...@@ -305,9 +320,26 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
if class_agnostic: if class_agnostic:
boxes_for_nms = boxes boxes_for_nms = boxes
else: else:
max_coordinate = boxes.max() # When using rotated boxes, only apply offsets on center.
offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes)) if boxes.size(-1) == 5:
boxes_for_nms = boxes + offsets[:, None] # Strictly, the maximum coordinates of the rotating box
# (x,y,w,h,a) should be calculated by polygon coordinates.
# But the conversion from rotated box to polygon will
# slow down the speed.
# So we use max(x,y) + max(w,h) as max coordinate
# which is larger than polygon max coordinate
# max(x1, y1, x2, y2,x3, y3, x4, y4)
max_coordinate = boxes[..., :2].max() + boxes[..., 2:4].max()
offsets = idxs.to(boxes) * (
max_coordinate + torch.tensor(1).to(boxes))
boxes_ctr_for_nms = boxes[..., :2] + offsets[:, None]
boxes_for_nms = torch.cat([boxes_ctr_for_nms, boxes[..., 2:5]],
dim=-1)
else:
max_coordinate = boxes.max()
offsets = idxs.to(boxes) * (
max_coordinate + torch.tensor(1).to(boxes))
boxes_for_nms = boxes + offsets[:, None]
nms_type = nms_cfg_.pop('type', 'nms') nms_type = nms_cfg_.pop('type', 'nms')
nms_op = eval(nms_type) nms_op = eval(nms_type)
...@@ -349,7 +381,8 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False): ...@@ -349,7 +381,8 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
return boxes, keep return boxes, keep
def nms_match(dets, iou_threshold): def nms_match(dets: array_like_type,
iou_threshold: float) -> List[array_like_type]:
"""Matched dets into different groups by NMS. """Matched dets into different groups by NMS.
NMS match is Similar to NMS but when a bbox is suppressed, nms match will NMS match is Similar to NMS but when a bbox is suppressed, nms match will
...@@ -358,7 +391,7 @@ def nms_match(dets, iou_threshold): ...@@ -358,7 +391,7 @@ def nms_match(dets, iou_threshold):
Args: Args:
dets (torch.Tensor | np.ndarray): Det boxes with scores, shape (N, 5). dets (torch.Tensor | np.ndarray): Det boxes with scores, shape (N, 5).
iou_thr (float): IoU thresh for NMS. iou_threshold (float): IoU thresh for NMS.
Returns: Returns:
list[torch.Tensor | np.ndarray]: The outer list corresponds different list[torch.Tensor | np.ndarray]: The outer list corresponds different
...@@ -370,7 +403,7 @@ def nms_match(dets, iou_threshold): ...@@ -370,7 +403,7 @@ def nms_match(dets, iou_threshold):
else: else:
assert dets.shape[-1] == 5, 'inputs dets.shape should be (N, 5), ' \ assert dets.shape[-1] == 5, 'inputs dets.shape should be (N, 5), ' \
f'but get {dets.shape}' f'but get {dets.shape}'
if isinstance(dets, torch.Tensor): if isinstance(dets, Tensor):
dets_t = dets.detach().cpu() dets_t = dets.detach().cpu()
else: else:
dets_t = torch.from_numpy(dets) dets_t = torch.from_numpy(dets)
...@@ -378,15 +411,19 @@ def nms_match(dets, iou_threshold): ...@@ -378,15 +411,19 @@ def nms_match(dets, iou_threshold):
indata_dict = {'iou_threshold': float(iou_threshold)} indata_dict = {'iou_threshold': float(iou_threshold)}
matched = ext_module.nms_match(*indata_list, **indata_dict) matched = ext_module.nms_match(*indata_list, **indata_dict)
if torch.__version__ == 'parrots': if torch.__version__ == 'parrots':
matched = matched.tolist() matched = matched.tolist() # type: ignore
if isinstance(dets, torch.Tensor): if isinstance(dets, Tensor):
return [dets.new_tensor(m, dtype=torch.long) for m in matched] return [dets.new_tensor(m, dtype=torch.long) for m in matched]
else: else:
return [np.array(m, dtype=int) for m in matched] return [np.array(m, dtype=int) for m in matched]
def nms_rotated(dets, scores, iou_threshold, labels=None, clockwise=True): def nms_rotated(dets: Tensor,
scores: Tensor,
iou_threshold: float,
labels: Optional[Tensor] = None,
clockwise: bool = True) -> Tuple[Tensor, Tensor]:
"""Performs non-maximum suppression (NMS) on the rotated boxes according to """Performs non-maximum suppression (NMS) on the rotated boxes according to
their intersection-over-union (IoU). their intersection-over-union (IoU).
...@@ -394,11 +431,12 @@ def nms_rotated(dets, scores, iou_threshold, labels=None, clockwise=True): ...@@ -394,11 +431,12 @@ def nms_rotated(dets, scores, iou_threshold, labels=None, clockwise=True):
IoU greater than iou_threshold with another (higher scoring) rotated box. IoU greater than iou_threshold with another (higher scoring) rotated box.
Args: Args:
dets (Tensor): Rotated boxes in shape (N, 5). They are expected to dets (torch.Tensor): Rotated boxes in shape (N, 5).
be in (x_ctr, y_ctr, width, height, angle_radian) format. They are expected to be in
scores (Tensor): scores in shape (N, ). (x_ctr, y_ctr, width, height, angle_radian) format.
scores (torch.Tensor): scores in shape (N, ).
iou_threshold (float): IoU thresh for NMS. iou_threshold (float): IoU thresh for NMS.
labels (Tensor): boxes' label in shape (N,). labels (torch.Tensor, optional): boxes' label in shape (N,).
clockwise (bool): flag indicating whether the positive angular clockwise (bool): flag indicating whether the positive angular
orientation is clockwise. default True. orientation is clockwise. default True.
`New in version 1.4.3.` `New in version 1.4.3.`
...@@ -417,7 +455,7 @@ def nms_rotated(dets, scores, iou_threshold, labels=None, clockwise=True): ...@@ -417,7 +455,7 @@ def nms_rotated(dets, scores, iou_threshold, labels=None, clockwise=True):
dets_cw = dets dets_cw = dets
multi_label = labels is not None multi_label = labels is not None
if multi_label: if multi_label:
dets_wl = torch.cat((dets_cw, labels.unsqueeze(1)), 1) dets_wl = torch.cat((dets_cw, labels.unsqueeze(1)), 1) # type: ignore
else: else:
dets_wl = dets_cw dets_wl = dets_cw
_, order = scores.sort(0, descending=True) _, order = scores.sort(0, descending=True)
......
...@@ -70,15 +70,46 @@ class TestNmsRotated: ...@@ -70,15 +70,46 @@ class TestNmsRotated:
assert np.allclose(dets.cpu().numpy()[:, :5], np_expect_dets) assert np.allclose(dets.cpu().numpy()[:, :5], np_expect_dets)
assert np.allclose(keep_inds.cpu().numpy(), np_expect_keep_inds) assert np.allclose(keep_inds.cpu().numpy(), np_expect_keep_inds)
def test_batched_nms(self):
# test batched_nms with nms_rotated # test batched_nms with nms_rotated
from mmcv.ops import batched_nms from mmcv.ops import batched_nms
np_boxes = np.array(
[[6.0, 3.0, 8.0, 7.0, 0.5, 0.7], [3.0, 6.0, 9.0, 11.0, 0.6, 0.8],
[3.0, 7.0, 10.0, 12.0, 0.3, 0.5], [1.0, 4.0, 13.0, 7.0, 0.6, 0.9]
],
dtype=np.float32)
np_labels = np.array([1, 0, 1, 0], dtype=np.float32)
np_expect_agnostic_dets = np.array(
[[1.0, 4.0, 13.0, 7.0, 0.6], [3.0, 6.0, 9.0, 11.0, 0.6],
[6.0, 3.0, 8.0, 7.0, 0.5]],
dtype=np.float32)
np_expect_agnostic_keep_inds = np.array([3, 1, 0], dtype=np.int64)
np_expect_dets = np.array(
[[1.0, 4.0, 13.0, 7.0, 0.6], [3.0, 6.0, 9.0, 11.0, 0.6],
[6.0, 3.0, 8.0, 7.0, 0.5], [3.0, 7.0, 10.0, 12.0, 0.3]],
dtype=np.float32)
np_expect_keep_inds = np.array([3, 1, 0, 2], dtype=np.int64)
nms_cfg = dict(type='nms_rotated', iou_threshold=0.5) nms_cfg = dict(type='nms_rotated', iou_threshold=0.5)
# test class_agnostic is True
boxes, keep = batched_nms(
torch.from_numpy(np_boxes[:, :5]),
torch.from_numpy(np_boxes[:, -1]),
torch.from_numpy(np_labels),
nms_cfg,
class_agnostic=True)
assert np.allclose(boxes.cpu().numpy()[:, :5], np_expect_agnostic_dets)
assert np.allclose(keep.cpu().numpy(), np_expect_agnostic_keep_inds)
# test class_agnostic is False
boxes, keep = batched_nms( boxes, keep = batched_nms(
torch.from_numpy(np_boxes[:, :5]), torch.from_numpy(np_boxes[:, :5]),
torch.from_numpy(np_boxes[:, -1]), torch.from_numpy(np_boxes[:, -1]),
torch.from_numpy(np.array([0, 0, 0, 0])), torch.from_numpy(np_labels),
nms_cfg, nms_cfg,
class_agnostic=False) class_agnostic=False)
assert np.allclose(boxes.cpu().numpy()[:, :5], np_expect_dets) assert np.allclose(boxes.cpu().numpy()[:, :5], np_expect_dets)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment