Unverified Commit 559b0558 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #119 from youansheng/patch-3

Update nms_wrapper.py
parents 3f5df4f0 5f9ca54d
......@@ -9,7 +9,9 @@ from .cpu_soft_nms import cpu_soft_nms
def nms(dets, thresh, device_id=None):
"""Dispatch to either CPU or GPU NMS implementations."""
tensor_device = None
if isinstance(dets, torch.Tensor):
tensor_device = dets.device
if dets.is_cuda:
device_id = dets.get_device()
dets = dets.detach().cpu().numpy()
......@@ -21,8 +23,8 @@ def nms(dets, thresh, device_id=None):
inds = (gpu_nms(dets, thresh, device_id=device_id)
if device_id is not None else cpu_nms(dets, thresh))
if isinstance(dets, torch.Tensor):
return dets.new_tensor(inds, dtype=torch.long)
if tensor_device:
return torch.Tensor(inds).long().to(tensor_device)
else:
return np.array(inds, dtype=np.int)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment