Unverified Commit 4897402a authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Fix inconsistent NMS implementation between CPU and CUDA (#1556)

* Fix inconsistent NMS implementation

* Improve tests for NMS

* Remove unnecessary using statement
parent 8909ff43
......@@ -1196,26 +1196,33 @@ class NMSTester(unittest.TestCase):
return torch.as_tensor(picked)
def _create_tensors(self, N):
def _create_tensors_with_iou(self, N, iou_thresh):
# force last box to have a pre-defined iou with the first box
# let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1],
# then, in order to satisfy ops.iou(b0, b1) == iou_thresh,
# we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh
boxes = torch.rand(N, 4) * 100
boxes[:, 2:] += torch.rand(N, 2) * 100
boxes[:, 2:] += boxes[:, :2]
boxes[-1, :] = boxes[0, :]
x0, y0, x1, y1 = boxes[-1].tolist()
boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh
scores = torch.rand(N)
return boxes, scores
def test_nms(self):
boxes, scores = self._create_tensors(1000)
err_msg = 'NMS incompatible between CPU and reference implementation for IoU={}'
for iou in [0.2, 0.5, 0.8]:
boxes, scores = self._create_tensors_with_iou(1000, iou)
keep_ref = self.reference_nms(boxes, scores, iou)
keep = ops.nms(boxes, scores, iou)
self.assertTrue(torch.allclose(keep, keep_ref), err_msg.format(iou))
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_nms_cuda(self):
boxes, scores = self._create_tensors(1000)
err_msg = 'NMS incompatible between CPU and CUDA for IoU={}'
for iou in [0.2, 0.5, 0.8]:
boxes, scores = self._create_tensors_with_iou(1000, iou)
r_cpu = ops.nms(boxes, scores, iou)
r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)
......
......@@ -61,7 +61,7 @@ at::Tensor nms_cpu_kernel(
auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1);
auto inter = w * h;
auto ovr = inter / (iarea + areas[j] - inter);
if (ovr >= iou_threshold)
if (ovr > iou_threshold)
suppressed[j] = 1;
}
}
......
......@@ -72,7 +72,6 @@ __global__ void nms_kernel(
at::Tensor nms_cuda(const at::Tensor& dets,
const at::Tensor& scores,
float iou_threshold) {
using scalar_t = float;
AT_ASSERTM(dets.type().is_cuda(), "dets must be a CUDA tensor");
AT_ASSERTM(scores.type().is_cuda(), "scores must be a CUDA tensor");
at::cuda::CUDAGuard device_guard(dets.device());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment