Commit 6325eb85 authored by Louis Maddox's avatar Louis Maddox Committed by ericmintun
Browse files

Move batched NMS indices to correct device (closes #17)

parent 2780a301
......@@ -214,7 +214,7 @@ class SamAutomaticMaskGenerator:
keep_by_nms = batched_nms(
data["boxes"].float(),
scores,
torch.zeros(len(data["boxes"])), # categories
torch.zeros_like(data["boxes"][:,0]), # categories
iou_threshold=self.crop_nms_thresh,
)
data.filter(keep_by_nms)
......@@ -251,7 +251,7 @@ class SamAutomaticMaskGenerator:
keep_by_nms = batched_nms(
data["boxes"].float(),
data["iou_preds"],
torch.zeros(len(data["boxes"])), # categories
torch.zeros_like(data["boxes"][:,0]), # categories
iou_threshold=self.box_nms_thresh,
)
data.filter(keep_by_nms)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment