"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "b7d2d92895c20fb8f24e09ef6dbf0709a83d99fc"
Unverified Commit c991db82 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

[OPS, TEST] Add onnx test for batched_nms (#3483)

parent 668927ef
...@@ -82,9 +82,10 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -82,9 +82,10 @@ class ONNXExporterTester(unittest.TestCase):
raise raise
def test_nms(self): def test_nms(self):
boxes = torch.rand(5, 4) num_boxes = 100
boxes[:, 2:] += torch.rand(5, 2) boxes = torch.rand(num_boxes, 4)
scores = torch.randn(5) boxes[:, 2:] += boxes[:, :2]
scores = torch.randn(num_boxes)
class Module(torch.nn.Module): class Module(torch.nn.Module):
def forward(self, boxes, scores): def forward(self, boxes, scores):
...@@ -92,6 +93,19 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -92,6 +93,19 @@ class ONNXExporterTester(unittest.TestCase):
self.run_model(Module(), [(boxes, scores)]) self.run_model(Module(), [(boxes, scores)])
def test_batched_nms(self):
num_boxes = 100
boxes = torch.rand(num_boxes, 4)
boxes[:, 2:] += boxes[:, :2]
scores = torch.randn(num_boxes)
idxs = torch.randint(0, 5, size=(num_boxes,))
class Module(torch.nn.Module):
def forward(self, boxes, scores, idxs):
return ops.batched_nms(boxes, scores, idxs, 0.5)
self.run_model(Module(), [(boxes, scores, idxs)])
def test_clip_boxes_to_image(self): def test_clip_boxes_to_image(self):
boxes = torch.randn(5, 4) * 500 boxes = torch.randn(5, 4) * 500
boxes[:, 2:] += boxes[:, :2] boxes[:, 2:] += boxes[:, :2]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment