Unverified Commit 11268ca7 authored by Ksenija Stanojevic's avatar Ksenija Stanojevic Committed by GitHub
Browse files

[ONNX] Fix roi_align ONNX export (#3355)



* add tests

* fix bug

* remove tests

* fix comment

* fix comment

* add warning

* fix syntax error

* fix python lint
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 9bccd5aa
......@@ -129,6 +129,11 @@ class ONNXExporterTester(unittest.TestCase):
model = ops.RoIAlign((5, 5), 1, 2)
self.run_model(model, [(x, single_roi)])
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1, -1)
self.run_model(model, [(x, single_roi)])
def test_roi_align_aligned(self):
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32)
......@@ -150,6 +155,11 @@ class ONNXExporterTester(unittest.TestCase):
model = ops.RoIAlign((2, 2), 2.5, 0, aligned=True)
self.run_model(model, [(x, single_roi)])
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
model = ops.RoIAlign((2, 2), 2.5, -1, aligned=True)
self.run_model(model, [(x, single_roi)])
@unittest.skip # Issue in exporting ROIAlign with aligned = True for malformed boxes
def test_roi_align_malformed_boxes(self):
x = torch.randn(1, 1, 10, 10, dtype=torch.float32)
......
......@@ -29,6 +29,12 @@ def _register_custom_op():
" ONNX forces ROIs to be 1x1 or larger.")
scale = torch.tensor(0.5 / spatial_scale).to(dtype=torch.float)
rois = g.op("Sub", rois, scale)
# ONNX doesn't support negative sampling_ratio
if sampling_ratio < 0:
warnings.warn("ONNX doesn't support negative sampling ratio,"
"therefore is is set to 0 in order to be exported.")
sampling_ratio = 0
return g.op('RoiAlign', input, rois, batch_indices, spatial_scale_f=spatial_scale,
output_height_i=pooled_height, output_width_i=pooled_width, sampling_ratio_i=sampling_ratio)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment