Unverified Commit f4a5446e authored by Wenwei Zhang, committed by GitHub

Support splitting batched_nms when the box number is too large (#516)

* Support to split batched_nms when box number is too large

* mv data from gpu to cpu

* Set split_thr through nms_cfg

* clean code

* Update motivation in docstring

* fix typos
parent 83d9a9c8
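The new option rides on the existing nms_cfg dict rather than a new function argument. A minimal, hypothetical usage sketch (the key names `type`, `iou_threshold` and `split_thr` are taken from the diff and the test below; the numeric values are illustrative, and the input tensors are assumed to exist):

    from mmcv.ops import batched_nms

    # boxes (N, 4), scores (N,) and idxs (N,) are assumed to come from a detector head.
    nms_cfg = dict(
        type='nms',            # which NMS op batched_nms dispatches to
        iou_threshold=0.7,     # IoU threshold forwarded to that op
        split_thr=10000)       # above this box count, run NMS per class sequentially
    dets, keep = batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False)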
@@ -223,9 +223,18 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
            and NMS will not be applied between elements of different idxs,
            shape (N, ).
        nms_cfg (dict): specifies the NMS type and other parameters such as
            iou_thr. Possible keys include the following.

            - iou_thr (float): IoU threshold used for NMS.
            - split_thr (float): threshold on the number of boxes. In some
              cases the number of boxes is large (e.g., 200k). To avoid OOM
              during training, users can set `split_thr` to a small value.
              If the number of boxes is greater than the threshold, NMS is
              applied to each group of boxes separately and sequentially.
              Defaults to 10000.
        class_agnostic (bool): if True, NMS is class agnostic, i.e. IoU
            thresholding happens over all boxes, regardless of the
            predicted class.

    Returns:
        tuple: kept dets and indices.
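As a reading aid for the Returns line above, a sketch of the shapes involved (K boxes surviving out of N; the shapes are inferred from the implementation below, not stated in the docstring, and the inputs are assumed to exist):

    # dets: (K, 5) tensor, kept boxes with their score appended as the last column
    # keep: (K,) tensor of indices into the original N boxes, sorted by descending score
    dets, keep = batched_nms(boxes, scores, idxs, nms_cfg)
    assert dets.shape[1] == 5 and keep.shape[0] == dets.shape[0]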
@@ -238,11 +247,27 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
    max_coordinate = boxes.max()
    offsets = idxs.to(boxes) * (max_coordinate + 1)
    boxes_for_nms = boxes + offsets[:, None]
    nms_type = nms_cfg_.pop('type', 'nms')
    nms_op = eval(nms_type)

    split_thr = nms_cfg_.pop('split_thr', 10000)
    if len(boxes_for_nms) < split_thr:
        dets, keep = nms_op(boxes_for_nms, scores, **nms_cfg_)
        boxes = boxes[keep]
        scores = dets[:, -1]
    else:
        total_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
        for id in torch.unique(idxs):
            mask = (idxs == id).nonzero(as_tuple=False).view(-1)
            dets, keep = nms_op(boxes_for_nms[mask], scores[mask], **nms_cfg_)
            total_mask[mask[keep]] = True
        keep = total_mask.nonzero(as_tuple=False).view(-1)
        keep = keep[scores[keep].argsort(descending=True)]
        boxes = boxes[keep]
        scores = scores[keep]

    return torch.cat([boxes, scores[:, None]], -1), keep
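Both branches above rely on the offset trick from the unchanged lines at the top of this hunk: boxes with different idxs are shifted into disjoint coordinate ranges, so a single class-agnostic NMS call cannot suppress boxes across classes. A toy illustration with made-up numbers:

    import torch

    boxes = torch.tensor([[0., 0., 10., 10.],
                          [1., 1., 11., 11.]])  # two heavily overlapping boxes
    idxs = torch.tensor([0, 1])                 # ...but from different classes
    max_coordinate = boxes.max()                # 11.0
    offsets = idxs.to(boxes) * (max_coordinate + 1)
    boxes_for_nms = boxes + offsets[:, None]
    # class 0 stays in [0, 11]; class 1 is shifted to [13, 23], so the shifted
    # boxes no longer overlap and a single NMS pass keeps both of them.

The new split branch reaches the same result without building the full offset-box tensor's pairwise IoU in one NMS call: it runs nms_op once per unique idx and merges the kept indices, trading speed for lower peak memory.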
@@ -132,3 +132,28 @@ class Testnms(object):
        wrong_dets = np.zeros((2, 3))
        with pytest.raises(AssertionError):
            nms_match(wrong_dets, iou_thr)

    def test_batched_nms(self):
        import mmcv
        from mmcv.ops import batched_nms

        results = mmcv.load('./tests/data/batched_nms_data.pkl')

        nms_cfg = dict(type='nms', iou_threshold=0.7)
        boxes, keep = batched_nms(
            results['boxes'],
            results['scores'],
            results['idxs'],
            nms_cfg,
            class_agnostic=False)

        nms_cfg.update(split_thr=100)
        seq_boxes, seq_keep = batched_nms(
            results['boxes'],
            results['scores'],
            results['idxs'],
            nms_cfg,
            class_agnostic=False)

        assert torch.equal(keep, seq_keep)
        assert torch.equal(boxes, seq_boxes)
        assert torch.equal(keep, results['keep'])