"vscode:/vscode.git/clone" did not exist on "4f499c7ffb5f6746c4ccd87b93b0ef09c32cf424"
Unverified Commit 0b41ff0b authored by Edward Z. Yang's avatar Edward Z. Yang Committed by GitHub
Browse files

Meta implementation for nms (#7944)


Signed-off-by: default avatarEdward Z. Yang <ezyang@meta.com>
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent d84aaae1
import functools import functools
import torch import torch
import torch._custom_ops
import torch.library import torch.library
# Ensure that torch.ops.torchvision is visible # Ensure that torch.ops.torchvision is visible
...@@ -48,3 +49,17 @@ def meta_roi_align_backward( ...@@ -48,3 +49,17 @@ def meta_roi_align_backward(
), ),
) )
return grad.new_empty((batch_size, channels, height, width)) return grad.new_empty((batch_size, channels, height, width))
@torch._custom_ops.impl_abstract("torchvision::nms")
def meta_nms(dets, scores, iou_threshold):
torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")
torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}")
torch._check(scores.dim() == 1, lambda: f"scores should be a 1d tensor, got {scores.dim()}")
torch._check(
dets.size(0) == scores.size(0),
lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}",
)
ctx = torch._custom_ops.get_ctx()
num_to_keep = ctx.create_unbacked_symint()
return dets.new_empty(num_to_keep, dtype=torch.long)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment