Unverified Commit 43b2f098 authored by Shilong Zhang's avatar Shilong Zhang Committed by GitHub
Browse files

[Feature]Support skip nms (#1552)

* skip nms

* judge at beginning

* add test

* remove else

* add more details in docstr including version not

* fix unitest

* fix doc

* fix doc

* fix typo

* resove conversation

* fix link
parent 88e01733
...@@ -260,20 +260,25 @@ def soft_nms(boxes, ...@@ -260,20 +260,25 @@ def soft_nms(boxes,
def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False): def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
r"""Performs non-maximum suppression in a batched fashion. r"""Performs non-maximum suppression in a batched fashion.
Modified from https://github.com/pytorch/vision/blob\ Modified from
/505cd6957711af790211896d32b40291bea1bc21/torchvision/ops/boxes.py#L39. https://github.com/pytorch/vision/blob/505cd6957711af790211896d32b40291bea1bc21/torchvision/ops/boxes.py#L39.
In order to perform NMS independently per class, we add an offset to all In order to perform NMS independently per class, we add an offset to all
the boxes. The offset is dependent only on the class idx, and is large the boxes. The offset is dependent only on the class idx, and is large
enough so that boxes from different classes do not overlap. enough so that boxes from different classes do not overlap.
Note:
In v1.4.1 and later, ``batched_nms`` supports skipping the NMS and
returns sorted raw results when `nms_cfg` is None.
Args: Args:
boxes (torch.Tensor): boxes in shape (N, 4). boxes (torch.Tensor): boxes in shape (N, 4).
scores (torch.Tensor): scores in shape (N, ). scores (torch.Tensor): scores in shape (N, ).
idxs (torch.Tensor): each index value correspond to a bbox cluster, idxs (torch.Tensor): each index value correspond to a bbox cluster,
and NMS will not be applied between elements of different idxs, and NMS will not be applied between elements of different idxs,
shape (N, ). shape (N, ).
nms_cfg (dict): specify nms type and other parameters like iou_thr. nms_cfg (dict | None): Supports skipping the nms when `nms_cfg`
Possible keys includes the following. is None, otherwise it should specify nms type and other
parameters like `iou_thr`. Possible keys includes the following.
- iou_thr (float): IoU threshold used for NMS. - iou_thr (float): IoU threshold used for NMS.
- split_thr (float): threshold number of boxes. In some cases the - split_thr (float): threshold number of boxes. In some cases the
...@@ -288,7 +293,19 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False): ...@@ -288,7 +293,19 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
Returns: Returns:
tuple: kept dets and indice. tuple: kept dets and indice.
- boxes (Tensor): Bboxes with score after nms, has shape
(num_bboxes, 5). last dimension 5 arrange as
(x1, y1, x2, y2, score)
- keep (Tensor): The indices of remaining boxes in input
boxes.
""" """
# skip nms when nms_cfg is None
if nms_cfg is None:
scores, inds = scores.sort(descending=True)
boxes = boxes[inds]
return torch.cat([boxes, scores[:, None]], -1), inds
nms_cfg_ = nms_cfg.copy() nms_cfg_ = nms_cfg.copy()
class_agnostic = nms_cfg_.pop('class_agnostic', class_agnostic) class_agnostic = nms_cfg_.pop('class_agnostic', class_agnostic)
if class_agnostic: if class_agnostic:
...@@ -333,7 +350,8 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False): ...@@ -333,7 +350,8 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
boxes = boxes[:max_num] boxes = boxes[:max_num]
scores = scores[:max_num] scores = scores[:max_num]
return torch.cat([boxes, scores[:, None]], -1), keep boxes = torch.cat([boxes, scores[:, None]], -1)
return boxes, keep
def nms_match(dets, iou_threshold): def nms_match(dets, iou_threshold):
......
...@@ -182,3 +182,14 @@ class Testnms(object): ...@@ -182,3 +182,14 @@ class Testnms(object):
assert torch.equal(keep, seq_keep) assert torch.equal(keep, seq_keep)
assert torch.equal(boxes, seq_boxes) assert torch.equal(boxes, seq_boxes)
# test skip nms when `nms_cfg` is None
seq_boxes, seq_keep = batched_nms(
torch.from_numpy(results['boxes']),
torch.from_numpy(results['scores']),
torch.from_numpy(results['idxs']),
None,
class_agnostic=False)
assert len(seq_keep) == len(results['boxes'])
# assert score is descending order
assert ((seq_boxes[:, -1][1:] - seq_boxes[:, -1][:-1]) < 0).all()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment