Unverified Commit 45f87fa3 authored by Bowen Bao's avatar Bowen Bao Committed by GitHub
Browse files

[ONNX] Support exporting RoiAlign align=True to ONNX with opset 16 (#6685)



* Support exporting RoiAlign align=True to ONNX with opset 16

* lint: ufmt
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 344ccc05
import io import io
from collections import OrderedDict from collections import OrderedDict
from typing import List, Tuple from typing import List, Optional, Tuple
import pytest import pytest
import torch import torch
...@@ -11,7 +11,7 @@ from torchvision.models.detection.image_list import ImageList ...@@ -11,7 +11,7 @@ from torchvision.models.detection.image_list import ImageList
from torchvision.models.detection.roi_heads import RoIHeads from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead
from torchvision.models.detection.transform import GeneralizedRCNNTransform from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.ops._register_onnx_ops import _onnx_opset_version from torchvision.ops import _register_onnx_ops
# In environments without onnxruntime we prefer to # In environments without onnxruntime we prefer to
# invoke all tests in the repo and have this one skipped rather than fail. # invoke all tests in the repo and have this one skipped rather than fail.
...@@ -32,7 +32,11 @@ class TestONNXExporter: ...@@ -32,7 +32,11 @@ class TestONNXExporter:
dynamic_axes=None, dynamic_axes=None,
output_names=None, output_names=None,
input_names=None, input_names=None,
opset_version: Optional[int] = None,
): ):
if opset_version is None:
opset_version = _register_onnx_ops.base_onnx_opset_version
model.eval() model.eval()
onnx_io = io.BytesIO() onnx_io = io.BytesIO()
...@@ -46,10 +50,11 @@ class TestONNXExporter: ...@@ -46,10 +50,11 @@ class TestONNXExporter:
torch_onnx_input, torch_onnx_input,
onnx_io, onnx_io,
do_constant_folding=do_constant_folding, do_constant_folding=do_constant_folding,
opset_version=_onnx_opset_version, opset_version=opset_version,
dynamic_axes=dynamic_axes, dynamic_axes=dynamic_axes,
input_names=input_names, input_names=input_names,
output_names=output_names, output_names=output_names,
verbose=True,
) )
# validate the exported model with onnx runtime # validate the exported model with onnx runtime
for test_inputs in inputs_list: for test_inputs in inputs_list:
...@@ -140,39 +145,39 @@ class TestONNXExporter: ...@@ -140,39 +145,39 @@ class TestONNXExporter:
model = ops.RoIAlign((5, 5), 1, -1) model = ops.RoIAlign((5, 5), 1, -1)
self.run_model(model, [(x, single_roi)]) self.run_model(model, [(x, single_roi)])
@pytest.mark.skip(reason="ROIAlign with aligned=True is not supported in ONNX, but will be supported in opset 16.")
def test_roi_align_aligned(self): def test_roi_align_aligned(self):
supported_onnx_version = _register_onnx_ops._onnx_opset_version_16
x = torch.rand(1, 1, 10, 10, dtype=torch.float32) x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32) single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1, 2, aligned=True) model = ops.RoIAlign((5, 5), 1, 2, aligned=True)
self.run_model(model, [(x, single_roi)]) self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)
x = torch.rand(1, 1, 10, 10, dtype=torch.float32) x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32) single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 0.5, 3, aligned=True) model = ops.RoIAlign((5, 5), 0.5, 3, aligned=True)
self.run_model(model, [(x, single_roi)]) self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)
x = torch.rand(1, 1, 10, 10, dtype=torch.float32) x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32) single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1.8, 2, aligned=True) model = ops.RoIAlign((5, 5), 1.8, 2, aligned=True)
self.run_model(model, [(x, single_roi)]) self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)
x = torch.rand(1, 1, 10, 10, dtype=torch.float32) x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32) single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
model = ops.RoIAlign((2, 2), 2.5, 0, aligned=True) model = ops.RoIAlign((2, 2), 2.5, 0, aligned=True)
self.run_model(model, [(x, single_roi)]) self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)
x = torch.rand(1, 1, 10, 10, dtype=torch.float32) x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32) single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
model = ops.RoIAlign((2, 2), 2.5, -1, aligned=True) model = ops.RoIAlign((2, 2), 2.5, -1, aligned=True)
self.run_model(model, [(x, single_roi)]) self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)
@pytest.mark.skip(reason="Issue in exporting ROIAlign with aligned = True for malformed boxes")
def test_roi_align_malformed_boxes(self): def test_roi_align_malformed_boxes(self):
supported_onnx_version = _register_onnx_ops._onnx_opset_version_16
x = torch.randn(1, 1, 10, 10, dtype=torch.float32) x = torch.randn(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 2, 0.3, 1.5, 1.5]], dtype=torch.float32) single_roi = torch.tensor([[0, 2, 0.3, 1.5, 1.5]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1, 1, aligned=True) model = ops.RoIAlign((5, 5), 1, 1, aligned=True)
self.run_model(model, [(x, single_roi)]) self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)
def test_roi_pool(self): def test_roi_pool(self):
x = torch.rand(1, 1, 10, 10, dtype=torch.float32) x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
......
...@@ -3,7 +3,9 @@ import warnings ...@@ -3,7 +3,9 @@ import warnings
import torch import torch
_onnx_opset_version = 11 _onnx_opset_version_11 = 11
_onnx_opset_version_16 = 16
base_onnx_opset_version = _onnx_opset_version_11
def _register_custom_op(): def _register_custom_op():
...@@ -20,32 +22,56 @@ def _register_custom_op(): ...@@ -20,32 +22,56 @@ def _register_custom_op():
nms_out = g.op("NonMaxSuppression", boxes, scores, max_output_per_class, iou_threshold) nms_out = g.op("NonMaxSuppression", boxes, scores, max_output_per_class, iou_threshold)
return squeeze(g, select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1) return squeeze(g, select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1)
@parse_args("v", "v", "f", "i", "i", "i", "i") def _process_batch_indices_for_roi_align(g, rois):
def roi_align(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): return _cast_Long(
batch_indices = _cast_Long(
g, squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1), False g, squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1), False
) )
rois = select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
# TODO: Remove this warning after ONNX opset 16 is supported.
if aligned:
warnings.warn(
"ROIAlign with aligned=True is not supported in ONNX, but will be supported in opset 16. "
"The workaround is that the user need apply the patch "
"https://github.com/microsoft/onnxruntime/pull/8564 "
"and build ONNXRuntime from source."
)
# ONNX doesn't support negative sampling_ratio def _process_rois_for_roi_align(g, rois):
return select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
def _process_sampling_ratio_for_roi_align(g, sampling_ratio: int):
if sampling_ratio < 0: if sampling_ratio < 0:
warnings.warn( warnings.warn(
"ONNX doesn't support negative sampling ratio, therefore is set to 0 in order to be exported." "ONNX export for RoIAlign with a non-zero sampling_ratio is not supported. "
"The model will be exported with a sampling_ratio of 0."
) )
sampling_ratio = 0 sampling_ratio = 0
return sampling_ratio
@parse_args("v", "v", "f", "i", "i", "i", "i")
def roi_align_opset11(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
batch_indices = _process_batch_indices_for_roi_align(g, rois)
rois = _process_rois_for_roi_align(g, rois)
if aligned:
warnings.warn(
"ROIAlign with aligned=True is not supported in ONNX, but is supported in opset 16. "
"Please export with opset 16 or higher to use aligned=False."
)
sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
return g.op(
"RoiAlign",
input,
rois,
batch_indices,
spatial_scale_f=spatial_scale,
output_height_i=pooled_height,
output_width_i=pooled_width,
sampling_ratio_i=sampling_ratio,
)
@parse_args("v", "v", "f", "i", "i", "i", "i")
def roi_align_opset16(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
batch_indices = _process_batch_indices_for_roi_align(g, rois)
rois = _process_rois_for_roi_align(g, rois)
coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel"
sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
return g.op( return g.op(
"RoiAlign", "RoiAlign",
input, input,
rois, rois,
batch_indices, batch_indices,
coordinate_transformation_mode_s=coordinate_transformation_mode,
spatial_scale_f=spatial_scale, spatial_scale_f=spatial_scale,
output_height_i=pooled_height, output_height_i=pooled_height,
output_width_i=pooled_width, output_width_i=pooled_width,
...@@ -61,6 +87,7 @@ def _register_custom_op(): ...@@ -61,6 +87,7 @@ def _register_custom_op():
from torch.onnx import register_custom_op_symbolic from torch.onnx import register_custom_op_symbolic
register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _onnx_opset_version) register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _onnx_opset_version_11)
register_custom_op_symbolic("torchvision::roi_align", roi_align, _onnx_opset_version) register_custom_op_symbolic("torchvision::roi_align", roi_align_opset11, _onnx_opset_version_11)
register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _onnx_opset_version) register_custom_op_symbolic("torchvision::roi_align", roi_align_opset16, _onnx_opset_version_16)
register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _onnx_opset_version_11)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment