Unverified Commit 414427dd authored by Nicolas Hug, committed by GitHub

[OPS, IMP] New batched_nms implementation (#3426)



* new batched_nms implem

* flake8

* hopefully fix torchscript tests

* Use where instead of nonzero

* Use same threshold (4k) for CPU and GPU

* Remove use of argsort

* use views again

* remove print

* trying stuff, I don't know what's going on

* previous version passed onnx checks, so the error isn't in the _vanilla func. Trying to return vanilla now

* add tracing decorators

* cleanup

* wip

* ignore new path with ONNX

* use vanilla if tracing...????

* Remove script_if_tracing decorator as it was conflicting with _is_tracing

* flake8

* Improve coverage
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent c991db82
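The diff below replaces the single coordinate-trick implementation behind torchvision.ops.batched_nms with a dispatch between two internal strategies. For orientation, a minimal, illustrative call to the public entry point (random data and a made-up threshold, not taken from the diff; the internal dispatch is invisible to callers):

import torch
from torchvision.ops import batched_nms

num_boxes = 100
# Boxes in (x1, y1, x2, y2) format; x1 < x2 and y1 < y2 by construction.
boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1)
scores = torch.rand(num_boxes)                 # one confidence score per box
idxs = torch.randint(0, 4, size=(num_boxes,))  # class label per box

# Indices of the kept boxes, sorted by decreasing score; boxes are only
# suppressed by higher-scoring boxes of the same class.
keep = batched_nms(boxes, scores, idxs, iou_threshold=0.5)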
test_onnx.py

+# onnxruntime requires python 3.5 or above
+try:
+    # This import should be before that of torch
+    # see https://github.com/onnx/onnx/issues/2394#issuecomment-581638840
+    import onnxruntime
+except ImportError:
+    onnxruntime = None
 from common_utils import set_rng_seed
 import io
 import torch
@@ -13,12 +21,6 @@ from torchvision.models.detection.mask_rcnn import MaskRCNNHeads, MaskRCNNPredic
 from collections import OrderedDict
-# onnxruntime requires python 3.5 or above
-try:
-    import onnxruntime
-except ImportError:
-    onnxruntime = None
 import unittest
 from torchvision.ops._register_onnx_ops import _onnx_opset_version
test_ops.py

@@ -461,6 +461,28 @@ class NMSTester(unittest.TestCase):
         keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
         self.assertTrue(torch.all(torch.eq(keep32, keep16)))

+    def test_batched_nms_implementations(self):
+        """Make sure that both implementations of batched_nms yield identical results"""
+        num_boxes = 1000
+        iou_threshold = .9
+
+        boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1)
+        assert max(boxes[:, 0]) < min(boxes[:, 2])  # x1 < x2
+        assert max(boxes[:, 1]) < min(boxes[:, 3])  # y1 < y2
+
+        scores = torch.rand(num_boxes)
+        idxs = torch.randint(0, 4, size=(num_boxes,))
+        keep_vanilla = ops.boxes._batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
+        keep_trick = ops.boxes._batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
+
+        err_msg = "The vanilla and the trick implementation yield different nms outputs."
+        self.assertTrue(torch.allclose(keep_vanilla, keep_trick), err_msg)
+
+        # Also make sure an empty tensor is returned if boxes is empty
+        empty = torch.empty((0,), dtype=torch.int64)
+        self.assertTrue(torch.allclose(empty, ops.batched_nms(empty, None, None, None)))

 class DeformConvTester(OpTester, unittest.TestCase):
     def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
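The new test exercises the private helpers directly through ops.boxes. A standalone sketch of the same equivalence check, assuming a torchvision build that contains this patch (the helper names come from the diff and are not public API):

import torch
from torchvision.ops import boxes as box_ops

torch.manual_seed(0)
num_boxes, iou_threshold = 1000, 0.9
boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1)
scores = torch.rand(num_boxes)
idxs = torch.randint(0, 4, size=(num_boxes,))

keep_vanilla = box_ops._batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
keep_trick = box_ops._batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
# Both paths must keep the same boxes in the same (score-sorted) order.
assert torch.equal(keep_vanilla, keep_trick)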
torchvision/ops/boxes.py

 import torch
 from torch import Tensor
-from typing import Tuple
+from typing import List, Tuple
 from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh
 import torchvision
 from torchvision.extension import _assert_has_ops
@@ -36,7 +36,6 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
     return torch.ops.torchvision.nms(boxes, scores, iou_threshold)

-@torch.jit._script_if_tracing
 def batched_nms(
     boxes: Tensor,
     scores: Tensor,
@@ -62,18 +61,50 @@ def batched_nms(
         the elements that have been kept by NMS, sorted
         in decreasing order of scores
     """
-    if boxes.numel() == 0:
-        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
-    # strategy: in order to perform NMS independently per class.
+    # Benchmarks that drove the following thresholds are at
+    # https://github.com/pytorch/vision/issues/1311#issuecomment-781329339
+    # Ideally for GPU we'd use a higher threshold
+    if boxes.numel() > 4_000 and not torchvision._is_tracing():
+        return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
+    else:
+        return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
+
+
+@torch.jit._script_if_tracing
+def _batched_nms_coordinate_trick(
+    boxes: Tensor,
+    scores: Tensor,
+    idxs: Tensor,
+    iou_threshold: float,
+) -> Tensor:
+    # strategy: in order to perform NMS independently per class,
     # we add an offset to all the boxes. The offset is dependent
     # only on the class idx, and is large enough so that boxes
     # from different classes do not overlap
-    else:
-        max_coordinate = boxes.max()
-        offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
-        boxes_for_nms = boxes + offsets[:, None]
-        keep = nms(boxes_for_nms, scores, iou_threshold)
-        return keep
+    if boxes.numel() == 0:
+        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
+    max_coordinate = boxes.max()
+    offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
+    boxes_for_nms = boxes + offsets[:, None]
+    keep = nms(boxes_for_nms, scores, iou_threshold)
+    return keep
+
+
+@torch.jit._script_if_tracing
+def _batched_nms_vanilla(
+    boxes: Tensor,
+    scores: Tensor,
+    idxs: Tensor,
+    iou_threshold: float,
+) -> Tensor:
+    # Based on Detectron2 implementation, just manually call nms() on each class independently
+    keep_mask = torch.zeros_like(scores, dtype=torch.bool)
+    for class_id in torch.unique(idxs):
+        curr_indices = torch.where(idxs == class_id)[0]
+        curr_keep_indices = nms(boxes[curr_indices], scores[curr_indices], iou_threshold)
+        keep_mask[curr_indices[curr_keep_indices]] = True
+    keep_indices = torch.where(keep_mask)[0]
+    return keep_indices[scores[keep_indices].sort(descending=True)[1]]

 def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
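_batched_nms_coordinate_trick avoids a Python loop over classes by shifting every box by an offset that depends only on its class label, so boxes from different classes can never overlap and a single nms() call behaves like per-class NMS. A small worked sketch of that offset arithmetic (toy values, not from the diff):

import torch

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.]])      # heavily overlapping boxes...
idxs = torch.tensor([0, 1])                     # ...but from different classes

max_coordinate = boxes.max()                         # 11.
offsets = idxs.to(boxes) * (max_coordinate + 1)      # class 0 -> 0., class 1 -> 12.
boxes_for_nms = boxes + offsets[:, None]             # second box becomes [13., 13., 23., 23.]
# The shifted boxes are disjoint, so nms() on boxes_for_nms keeps both of them,
# exactly as per-class NMS should.

_batched_nms_vanilla instead runs nms() once per class and merges the results; batched_nms switches to it above 4,000 boxes (outside tracing), a threshold driven by the benchmarks linked in the code comment, with the same value used for CPU and GPU even though a higher GPU threshold would be ideal.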