"src/diffusers/models/modeling_flax_utils.py" did not exist on "8aac1f99d7af5873db7d23c07fba370d0f5061a6"
Unverified Commit 5f9ca54d authored by YouAnsheng's avatar YouAnsheng Committed by GitHub
Browse files

Update nms_wrapper.py

Fix bug for nms_wrapper.
parent 3f5df4f0
...@@ -9,7 +9,9 @@ from .cpu_soft_nms import cpu_soft_nms ...@@ -9,7 +9,9 @@ from .cpu_soft_nms import cpu_soft_nms
def nms(dets, thresh, device_id=None): def nms(dets, thresh, device_id=None):
"""Dispatch to either CPU or GPU NMS implementations.""" """Dispatch to either CPU or GPU NMS implementations."""
tensor_device = None
if isinstance(dets, torch.Tensor): if isinstance(dets, torch.Tensor):
tensor_device = dets.device
if dets.is_cuda: if dets.is_cuda:
device_id = dets.get_device() device_id = dets.get_device()
dets = dets.detach().cpu().numpy() dets = dets.detach().cpu().numpy()
...@@ -21,8 +23,8 @@ def nms(dets, thresh, device_id=None): ...@@ -21,8 +23,8 @@ def nms(dets, thresh, device_id=None):
inds = (gpu_nms(dets, thresh, device_id=device_id) inds = (gpu_nms(dets, thresh, device_id=device_id)
if device_id is not None else cpu_nms(dets, thresh)) if device_id is not None else cpu_nms(dets, thresh))
if isinstance(dets, torch.Tensor): if tensor_device:
return dets.new_tensor(inds, dtype=torch.long) return torch.Tensor(inds).long().to(tensor_device)
else: else:
return np.array(inds, dtype=np.int) return np.array(inds, dtype=np.int)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment