Commit e1f6ef2c authored by Jon Crall's avatar Jon Crall Committed by Kai Chen
Browse files

Add tests for NMS / Issue with NMS on GPU 1. (#1603)

* Add tests for NMS

* Fix linting errors

* Add DeviceGaurd to nms_cuda
parent 6304b647
......@@ -21,6 +21,18 @@ def nms(dets, iou_thr, device_id=None):
Returns:
tuple: kept bboxes and indice, which is always the same data type as
the input.
Example:
>>> dets = np.array([[49.1, 32.4, 51.0, 35.9, 0.9],
>>> [49.3, 32.9, 51.0, 35.3, 0.9],
>>> [49.2, 31.8, 51.0, 35.4, 0.5],
>>> [35.1, 11.5, 39.1, 15.7, 0.5],
>>> [35.6, 11.8, 39.3, 14.2, 0.5],
>>> [35.3, 11.5, 39.9, 14.5, 0.4],
>>> [35.2, 11.7, 39.7, 15.7, 0.3]], dtype=np.float32)
>>> iou_thr = 0.7
>>> supressed, inds = nms(dets, iou_thr)
>>> assert len(inds) == len(supressed) == 3
"""
# convert dets (tensor or numpy array) to tensor
if isinstance(dets, torch.Tensor):
......@@ -50,6 +62,18 @@ def nms(dets, iou_thr, device_id=None):
def soft_nms(dets, iou_thr, method='linear', sigma=0.5, min_score=1e-3):
"""
Example:
>>> dets = np.array([[4., 3., 5., 3., 0.9],
>>> [4., 3., 5., 4., 0.9],
>>> [3., 1., 3., 1., 0.5],
>>> [3., 1., 3., 1., 0.5],
>>> [3., 1., 3., 1., 0.4],
>>> [3., 1., 3., 1., 0.0]], dtype=np.float32)
>>> iou_thr = 0.7
>>> supressed, inds = soft_nms(dets, iou_thr, sigma=0.5)
>>> assert len(inds) == len(supressed) == 3
"""
if isinstance(dets, torch.Tensor):
is_tensor = True
dets_np = dets.detach().cpu().numpy()
......
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/DeviceGuard.h>
#include <THC/THC.h>
#include <THC/THCDeviceUtils.cuh>
......@@ -68,6 +69,10 @@ __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
// boxes is a N x 5 tensor
at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
// Ensure CUDA uses the input tensor device.
at::DeviceGuard guard(boxes.device());
using scalar_t = float;
AT_ASSERTM(boxes.type().is_cuda(), "boxes must be a CUDA tensor");
auto scores = boxes.select(1, 4);
......@@ -128,4 +133,4 @@ at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
order_t.device(), keep.scalar_type())
}).sort(0, false));
}
\ No newline at end of file
}
"""
CommandLine:
pytest tests/test_nms.py
"""
import numpy as np
import torch
from mmdet.ops.nms.nms_wrapper import nms
def test_nms_device_and_dtypes_cpu():
"""
CommandLine:
xdoctest -m tests/test_nms.py test_nms_device_and_dtypes_cpu
"""
iou_thr = 0.7
base_dets = np.array([[49.1, 32.4, 51.0, 35.9, 0.9],
[49.3, 32.9, 51.0, 35.3, 0.9],
[35.3, 11.5, 39.9, 14.5, 0.4],
[35.2, 11.7, 39.7, 15.7, 0.3]])
# CPU can handle float32 and float64
dets = base_dets.astype(np.float32)
supressed, inds = nms(dets, iou_thr)
assert dets.dtype == supressed.dtype
assert len(inds) == len(supressed) == 3
dets = torch.FloatTensor(base_dets)
surpressed, inds = nms(dets, iou_thr)
assert dets.dtype == surpressed.dtype
assert len(inds) == len(surpressed) == 3
dets = base_dets.astype(np.float64)
supressed, inds = nms(dets, iou_thr)
assert dets.dtype == supressed.dtype
assert len(inds) == len(supressed) == 3
dets = torch.DoubleTensor(base_dets)
surpressed, inds = nms(dets, iou_thr)
assert dets.dtype == surpressed.dtype
assert len(inds) == len(surpressed) == 3
def test_nms_device_and_dtypes_gpu():
"""
CommandLine:
xdoctest -m tests/test_nms.py test_nms_device_and_dtypes_gpu
"""
if not torch.cuda.is_available():
import pytest
pytest.skip('test requires GPU and torch+cuda')
iou_thr = 0.7
base_dets = np.array([[49.1, 32.4, 51.0, 35.9, 0.9],
[49.3, 32.9, 51.0, 35.3, 0.9],
[35.3, 11.5, 39.9, 14.5, 0.4],
[35.2, 11.7, 39.7, 15.7, 0.3]])
for device_id in range(torch.cuda.device_count()):
print('Run NMS on device_id = {!r}'.format(device_id))
# GPU can handle float32 but not float64
dets = base_dets.astype(np.float32)
supressed, inds = nms(dets, iou_thr, device_id)
assert dets.dtype == supressed.dtype
assert len(inds) == len(supressed) == 3
dets = torch.FloatTensor(base_dets).to(device_id)
surpressed, inds = nms(dets, iou_thr)
assert dets.dtype == surpressed.dtype
assert len(inds) == len(surpressed) == 3
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment