Unverified Commit 414427dd authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

[OPS, IMP] New batched_nms implementation (#3426)



* new batched_nms implem

* flake8

* hopefully fix torchscipt tests

* Use where instead of nonzero

* Use same threshold (4k) for CPU and GPU

* Remove use of argsort

* use views again

* remove print

* trying stuff, I don't know what's going on

* previous passed onnx checks so the error isn't in _vanilla func. Trying to return vanilla now

* add tracing decorators

* cleanup

* wip

* ignore new path with ONNX

* use vanilla if tracing...????

* Remove script_if_tracing decorator as it was conflicting with _is_tracing

* flake8

* Improve coverage
Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
parent c991db82
# onnxruntime requires python 3.5 or above
try:
# This import should be before that of torch
# see https://github.com/onnx/onnx/issues/2394#issuecomment-581638840
import onnxruntime
except ImportError:
onnxruntime = None
from common_utils import set_rng_seed
import io
import torch
......@@ -13,12 +21,6 @@ from torchvision.models.detection.mask_rcnn import MaskRCNNHeads, MaskRCNNPredic
from collections import OrderedDict
# onnxruntime requires python 3.5 or above
try:
import onnxruntime
except ImportError:
onnxruntime = None
import unittest
from torchvision.ops._register_onnx_ops import _onnx_opset_version
......
......@@ -461,6 +461,28 @@ class NMSTester(unittest.TestCase):
keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
self.assertTrue(torch.all(torch.eq(keep32, keep16)))
def test_batched_nms_implementations(self):
"""Make sure that both implementations of batched_nms yield identical results"""
num_boxes = 1000
iou_threshold = .9
boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1)
assert max(boxes[:, 0]) < min(boxes[:, 2]) # x1 < x2
assert max(boxes[:, 1]) < min(boxes[:, 3]) # y1 < y2
scores = torch.rand(num_boxes)
idxs = torch.randint(0, 4, size=(num_boxes,))
keep_vanilla = ops.boxes._batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
keep_trick = ops.boxes._batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
err_msg = "The vanilla and the trick implementation yield different nms outputs."
self.assertTrue(torch.allclose(keep_vanilla, keep_trick), err_msg)
# Also make sure an empty tensor is returned if boxes is empty
empty = torch.empty((0,), dtype=torch.int64)
self.assertTrue(torch.allclose(empty, ops.batched_nms(empty, None, None, None)))
class DeformConvTester(OpTester, unittest.TestCase):
def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
......
import torch
from torch import Tensor
from typing import Tuple
from typing import List, Tuple
from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh
import torchvision
from torchvision.extension import _assert_has_ops
......@@ -36,7 +36,6 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
@torch.jit._script_if_tracing
def batched_nms(
boxes: Tensor,
scores: Tensor,
......@@ -62,13 +61,28 @@ def batched_nms(
the elements that have been kept by NMS, sorted
in decreasing order of scores
"""
if boxes.numel() == 0:
return torch.empty((0,), dtype=torch.int64, device=boxes.device)
# strategy: in order to perform NMS independently per class.
# Benchmarks that drove the following thresholds are at
# https://github.com/pytorch/vision/issues/1311#issuecomment-781329339
# Ideally for GPU we'd use a higher threshold
if boxes.numel() > 4_000 and not torchvision._is_tracing():
return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
else:
return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
@torch.jit._script_if_tracing
def _batched_nms_coordinate_trick(
boxes: Tensor,
scores: Tensor,
idxs: Tensor,
iou_threshold: float,
) -> Tensor:
# strategy: in order to perform NMS independently per class,
# we add an offset to all the boxes. The offset is dependent
# only on the class idx, and is large enough so that boxes
# from different classes do not overlap
else:
if boxes.numel() == 0:
return torch.empty((0,), dtype=torch.int64, device=boxes.device)
max_coordinate = boxes.max()
offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
boxes_for_nms = boxes + offsets[:, None]
......@@ -76,6 +90,23 @@ def batched_nms(
return keep
@torch.jit._script_if_tracing
def _batched_nms_vanilla(
boxes: Tensor,
scores: Tensor,
idxs: Tensor,
iou_threshold: float,
) -> Tensor:
# Based on Detectron2 implementation, just manually call nms() on each class independently
keep_mask = torch.zeros_like(scores, dtype=torch.bool)
for class_id in torch.unique(idxs):
curr_indices = torch.where(idxs == class_id)[0]
curr_keep_indices = nms(boxes[curr_indices], scores[curr_indices], iou_threshold)
keep_mask[curr_indices[curr_keep_indices]] = True
keep_indices = torch.where(keep_mask)[0]
return keep_indices[scores[keep_indices].sort(descending=True)[1]]
def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
"""
Remove boxes which contains at least one side smaller than min_size.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment