Unverified commit 45f87fa3 authored by Bowen Bao, committed by GitHub

[ONNX] Support exporting RoiAlign aligned=True to ONNX with opset 16 (#6685)



* Support exporting RoiAlign aligned=True to ONNX with opset 16

* lint: ufmt
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 344ccc05
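
For context before the diff (not part of the commit itself): a minimal sketch of what this change enables, assuming a PyTorch build whose ONNX exporter supports opset 16 and an onnxruntime build that implements it. The shapes and RoI values mirror the updated test_roi_align_aligned test below; the tolerance values are illustrative.

import io

import numpy as np
import onnxruntime
import torch
from torchvision import ops

# RoIAlign with aligned=True: exporting it now just needs opset_version=16.
model = ops.RoIAlign((5, 5), spatial_scale=1.0, sampling_ratio=2, aligned=True)
model.eval()

x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
rois = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32)

onnx_io = io.BytesIO()
torch.onnx.export(model, (x, rois), onnx_io, opset_version=16)

# Run the exported graph with onnxruntime and compare against eager mode,
# which is essentially what run_model() below does for each test input.
sess = onnxruntime.InferenceSession(onnx_io.getvalue(), providers=["CPUExecutionProvider"])
feeds = {inp.name: t.numpy() for inp, t in zip(sess.get_inputs(), (x, rois))}
(ort_out,) = sess.run(None, feeds)
np.testing.assert_allclose(model(x, rois).numpy(), ort_out, rtol=1e-3, atol=1e-5)
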
import io
from collections import OrderedDict
from typing import List, Tuple
from typing import List, Optional, Tuple
import pytest
import torch
@@ -11,7 +11,7 @@ from torchvision.models.detection.image_list import ImageList
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.ops._register_onnx_ops import _onnx_opset_version
from torchvision.ops import _register_onnx_ops
# In environments without onnxruntime we prefer to
# invoke all tests in the repo and have this one skipped rather than fail.
@@ -32,7 +32,11 @@ class TestONNXExporter:
dynamic_axes=None,
output_names=None,
input_names=None,
opset_version: Optional[int] = None,
):
if opset_version is None:
opset_version = _register_onnx_ops.base_onnx_opset_version
model.eval()
onnx_io = io.BytesIO()
@@ -46,10 +50,11 @@
torch_onnx_input,
onnx_io,
do_constant_folding=do_constant_folding,
opset_version=_onnx_opset_version,
opset_version=opset_version,
dynamic_axes=dynamic_axes,
input_names=input_names,
output_names=output_names,
verbose=True,
)
# validate the exported model with onnx runtime
for test_inputs in inputs_list:
@@ -140,39 +145,39 @@ class TestONNXExporter:
model = ops.RoIAlign((5, 5), 1, -1)
self.run_model(model, [(x, single_roi)])
@pytest.mark.skip(reason="ROIAlign with aligned=True is not supported in ONNX, but will be supported in opset 16.")
def test_roi_align_aligned(self):
supported_onnx_version = _register_onnx_ops._onnx_opset_version_16
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1, 2, aligned=True)
self.run_model(model, [(x, single_roi)])
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 0.5, 3, aligned=True)
self.run_model(model, [(x, single_roi)])
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1.8, 2, aligned=True)
self.run_model(model, [(x, single_roi)])
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
model = ops.RoIAlign((2, 2), 2.5, 0, aligned=True)
self.run_model(model, [(x, single_roi)])
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
model = ops.RoIAlign((2, 2), 2.5, -1, aligned=True)
self.run_model(model, [(x, single_roi)])
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)
@pytest.mark.skip(reason="Issue in exporting ROIAlign with aligned = True for malformed boxes")
def test_roi_align_malformed_boxes(self):
supported_onnx_version = _register_onnx_ops._onnx_opset_version_16
x = torch.randn(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 2, 0.3, 1.5, 1.5]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1, 1, aligned=True)
self.run_model(model, [(x, single_roi)])
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)
def test_roi_pool(self):
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
......
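
Between the test changes above and the symbolic-registration changes below, the behavior at the old opset is worth spelling out: with opset_version=11 the export still succeeds, but roi_align_opset11 (defined below) drops aligned=True after emitting a warning. A small sketch, not part of the diff, that surfaces that warning:

import io
import warnings

import torch
from torchvision import ops

model = ops.RoIAlign((5, 5), spatial_scale=1.0, sampling_ratio=2, aligned=True)
model.eval()
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
rois = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32)

# Exporting at opset 11 still works, but the aligned=True behavior is lost:
# the symbolic warns and exports as if aligned=False.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torch.onnx.export(model, (x, rois), io.BytesIO(), opset_version=11)
print([str(w.message) for w in caught])
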
@@ -3,7 +3,9 @@ import warnings
import torch
_onnx_opset_version = 11
_onnx_opset_version_11 = 11
_onnx_opset_version_16 = 16
base_onnx_opset_version = _onnx_opset_version_11
def _register_custom_op():
@@ -20,32 +22,56 @@ def _register_custom_op():
nms_out = g.op("NonMaxSuppression", boxes, scores, max_output_per_class, iou_threshold)
return squeeze(g, select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1)
@parse_args("v", "v", "f", "i", "i", "i", "i")
def roi_align(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
batch_indices = _cast_Long(
def _process_batch_indices_for_roi_align(g, rois):
return _cast_Long(
g, squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1), False
)
rois = select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
# TODO: Remove this warning after ONNX opset 16 is supported.
if aligned:
warnings.warn(
"ROIAlign with aligned=True is not supported in ONNX, but will be supported in opset 16. "
"The workaround is that the user need apply the patch "
"https://github.com/microsoft/onnxruntime/pull/8564 "
"and build ONNXRuntime from source."
)
# ONNX doesn't support negative sampling_ratio
def _process_rois_for_roi_align(g, rois):
return select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
def _process_sampling_ratio_for_roi_align(g, sampling_ratio: int):
if sampling_ratio < 0:
warnings.warn(
"ONNX doesn't support negative sampling ratio, therefore is set to 0 in order to be exported."
"ONNX export for RoIAlign with a non-zero sampling_ratio is not supported. "
"The model will be exported with a sampling_ratio of 0."
)
sampling_ratio = 0
return sampling_ratio
@parse_args("v", "v", "f", "i", "i", "i", "i")
def roi_align_opset11(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
batch_indices = _process_batch_indices_for_roi_align(g, rois)
rois = _process_rois_for_roi_align(g, rois)
if aligned:
warnings.warn(
"ROIAlign with aligned=True is not supported in ONNX, but is supported in opset 16. "
"Please export with opset 16 or higher to use aligned=False."
)
sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
return g.op(
"RoiAlign",
input,
rois,
batch_indices,
spatial_scale_f=spatial_scale,
output_height_i=pooled_height,
output_width_i=pooled_width,
sampling_ratio_i=sampling_ratio,
)
@parse_args("v", "v", "f", "i", "i", "i", "i")
def roi_align_opset16(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
batch_indices = _process_batch_indices_for_roi_align(g, rois)
rois = _process_rois_for_roi_align(g, rois)
coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel"
sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
return g.op(
"RoiAlign",
input,
rois,
batch_indices,
coordinate_transformation_mode_s=coordinate_transformation_mode,
spatial_scale_f=spatial_scale,
output_height_i=pooled_height,
output_width_i=pooled_width,
@@ -61,6 +87,7 @@ def _register_custom_op():
from torch.onnx import register_custom_op_symbolic
register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _onnx_opset_version)
register_custom_op_symbolic("torchvision::roi_align", roi_align, _onnx_opset_version)
register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _onnx_opset_version)
register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _onnx_opset_version_11)
register_custom_op_symbolic("torchvision::roi_align", roi_align_opset11, _onnx_opset_version_11)
register_custom_op_symbolic("torchvision::roi_align", roi_align_opset16, _onnx_opset_version_16)
register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _onnx_opset_version_11)