Unverified commit 30bb1cea authored by Justin Chu, committed by GitHub

[ONNX] misc improvements (#7249)


Co-authored-by: Nikita Shulga <nshulga@fb.com>
parent d805aeae
@@ -34,7 +34,7 @@ class TestONNXExporter:
         opset_version: Optional[int] = None,
     ):
         if opset_version is None:
-            opset_version = _register_onnx_ops.base_onnx_opset_version
+            opset_version = _register_onnx_ops.BASE_ONNX_OPSET_VERSION
 
         model.eval()
@@ -139,7 +139,7 @@ class TestONNXExporter:
         self.run_model(model, [(x, single_roi)])
 
     def test_roi_align_aligned(self):
-        supported_onnx_version = _register_onnx_ops._onnx_opset_version_16
+        supported_onnx_version = _register_onnx_ops._ONNX_OPSET_VERSION_16
         x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
         single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32)
         model = ops.RoIAlign((5, 5), 1, 2, aligned=True)
@@ -166,7 +166,7 @@ class TestONNXExporter:
         self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)
 
     def test_roi_align_malformed_boxes(self):
-        supported_onnx_version = _register_onnx_ops._onnx_opset_version_16
+        supported_onnx_version = _register_onnx_ops._ONNX_OPSET_VERSION_16
         x = torch.randn(1, 1, 10, 10, dtype=torch.float32)
         single_roi = torch.tensor([[0, 2, 0.3, 1.5, 1.5]], dtype=torch.float32)
         model = ops.RoIAlign((5, 5), 1, 1, aligned=True)
@@ -2,96 +2,106 @@ import sys
 import warnings
 
 import torch
+from torch.onnx import symbolic_opset11 as opset11
+from torch.onnx.symbolic_helper import parse_args
 
-_onnx_opset_version_11 = 11
-_onnx_opset_version_16 = 16
-base_onnx_opset_version = _onnx_opset_version_11
-
-
-def _register_custom_op():
-    from torch.onnx.symbolic_helper import parse_args
-    from torch.onnx.symbolic_opset11 import select, squeeze, unsqueeze
-
-    @parse_args("v", "v", "f")
-    def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
-        boxes = unsqueeze(g, boxes, 0)
-        scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
-        max_output_per_class = g.op("Constant", value_t=torch.tensor([sys.maxsize], dtype=torch.long))
-        iou_threshold = g.op("Constant", value_t=torch.tensor([iou_threshold], dtype=torch.float))
-        nms_out = g.op(
-            "NonMaxSuppression",
-            g.op("Cast", boxes, to_i=torch.onnx.TensorProtoDataType.FLOAT),
-            g.op("Cast", scores, to_i=torch.onnx.TensorProtoDataType.FLOAT),
-            max_output_per_class,
-            iou_threshold,
-        )
-        return squeeze(g, select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1)
-
-    def _process_batch_indices_for_roi_align(g, rois):
-        indices = squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1)
-        return g.op("Cast", indices, to_i=torch.onnx.TensorProtoDataType.INT64)
-
-    def _process_rois_for_roi_align(g, rois):
-        return select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
-
-    def _process_sampling_ratio_for_roi_align(g, sampling_ratio: int):
-        if sampling_ratio < 0:
-            warnings.warn(
-                "ONNX export for RoIAlign with a non-zero sampling_ratio is not supported. "
-                "The model will be exported with a sampling_ratio of 0."
-            )
-            sampling_ratio = 0
-        return sampling_ratio
-
-    @parse_args("v", "v", "f", "i", "i", "i", "i")
-    def roi_align_opset11(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
-        batch_indices = _process_batch_indices_for_roi_align(g, rois)
-        rois = _process_rois_for_roi_align(g, rois)
-        if aligned:
-            warnings.warn(
-                "ROIAlign with aligned=True is only supported in opset >= 16. "
-                "Please export with opset 16 or higher, or use aligned=False."
-            )
-        sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
-        return g.op(
-            "RoiAlign",
-            input,
-            rois,
-            batch_indices,
-            spatial_scale_f=spatial_scale,
-            output_height_i=pooled_height,
-            output_width_i=pooled_width,
-            sampling_ratio_i=sampling_ratio,
-        )
-
-    @parse_args("v", "v", "f", "i", "i", "i", "i")
-    def roi_align_opset16(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
-        batch_indices = _process_batch_indices_for_roi_align(g, rois)
-        rois = _process_rois_for_roi_align(g, rois)
-        coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel"
-        sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
-        return g.op(
-            "RoiAlign",
-            input,
-            rois,
-            batch_indices,
-            coordinate_transformation_mode_s=coordinate_transformation_mode,
-            spatial_scale_f=spatial_scale,
-            output_height_i=pooled_height,
-            output_width_i=pooled_width,
-            sampling_ratio_i=sampling_ratio,
-        )
-
-    @parse_args("v", "v", "f", "i", "i")
-    def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width):
-        roi_pool = g.op(
-            "MaxRoiPool", input, rois, pooled_shape_i=(pooled_height, pooled_width), spatial_scale_f=spatial_scale
-        )
-        return roi_pool, None
-
-    from torch.onnx import register_custom_op_symbolic
-
-    register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _onnx_opset_version_11)
-    register_custom_op_symbolic("torchvision::roi_align", roi_align_opset11, _onnx_opset_version_11)
-    register_custom_op_symbolic("torchvision::roi_align", roi_align_opset16, _onnx_opset_version_16)
-    register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _onnx_opset_version_11)
+_ONNX_OPSET_VERSION_11 = 11
+_ONNX_OPSET_VERSION_16 = 16
+BASE_ONNX_OPSET_VERSION = _ONNX_OPSET_VERSION_11
+
+
+@parse_args("v", "v", "f")
+def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
+    boxes = opset11.unsqueeze(g, boxes, 0)
+    scores = opset11.unsqueeze(g, opset11.unsqueeze(g, scores, 0), 0)
+    max_output_per_class = g.op("Constant", value_t=torch.tensor([sys.maxsize], dtype=torch.long))
+    iou_threshold = g.op("Constant", value_t=torch.tensor([iou_threshold], dtype=torch.float))
+
+    # Cast boxes and scores to float32 in case they are float64 inputs
+    nms_out = g.op(
+        "NonMaxSuppression",
+        g.op("Cast", boxes, to_i=torch.onnx.TensorProtoDataType.FLOAT),
+        g.op("Cast", scores, to_i=torch.onnx.TensorProtoDataType.FLOAT),
+        max_output_per_class,
+        iou_threshold,
+    )
+    return opset11.squeeze(
+        g, opset11.select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1
+    )
+
+
+def _process_batch_indices_for_roi_align(g, rois):
+    indices = opset11.squeeze(
+        g, opset11.select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1
+    )
+    return g.op("Cast", indices, to_i=torch.onnx.TensorProtoDataType.INT64)
+
+
+def _process_rois_for_roi_align(g, rois):
+    return opset11.select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
+
+
+def _process_sampling_ratio_for_roi_align(g, sampling_ratio: int):
+    if sampling_ratio < 0:
+        warnings.warn(
+            "ONNX export for RoIAlign with a non-zero sampling_ratio is not supported. "
+            "The model will be exported with a sampling_ratio of 0."
+        )
+        sampling_ratio = 0
+    return sampling_ratio
+
+
+@parse_args("v", "v", "f", "i", "i", "i", "i")
+def roi_align_opset11(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
+    batch_indices = _process_batch_indices_for_roi_align(g, rois)
+    rois = _process_rois_for_roi_align(g, rois)
+    if aligned:
+        warnings.warn(
+            "ROIAlign with aligned=True is only supported in opset >= 16. "
+            "Please export with opset 16 or higher, or use aligned=False."
+        )
+    sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
+    return g.op(
+        "RoiAlign",
+        input,
+        rois,
+        batch_indices,
+        spatial_scale_f=spatial_scale,
+        output_height_i=pooled_height,
+        output_width_i=pooled_width,
+        sampling_ratio_i=sampling_ratio,
+    )
+
+
+@parse_args("v", "v", "f", "i", "i", "i", "i")
+def roi_align_opset16(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
+    batch_indices = _process_batch_indices_for_roi_align(g, rois)
+    rois = _process_rois_for_roi_align(g, rois)
+    coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel"
+    sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
+    return g.op(
+        "RoiAlign",
+        input,
+        rois,
+        batch_indices,
+        coordinate_transformation_mode_s=coordinate_transformation_mode,
+        spatial_scale_f=spatial_scale,
+        output_height_i=pooled_height,
+        output_width_i=pooled_width,
+        sampling_ratio_i=sampling_ratio,
+    )
+
+
+@parse_args("v", "v", "f", "i", "i")
+def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width):
+    roi_pool = g.op(
+        "MaxRoiPool", input, rois, pooled_shape_i=(pooled_height, pooled_width), spatial_scale_f=spatial_scale
+    )
+    return roi_pool, None
+
+
+def _register_custom_op():
+    torch.onnx.register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _ONNX_OPSET_VERSION_11)
+    torch.onnx.register_custom_op_symbolic("torchvision::roi_align", roi_align_opset11, _ONNX_OPSET_VERSION_11)
+    torch.onnx.register_custom_op_symbolic("torchvision::roi_align", roi_align_opset16, _ONNX_OPSET_VERSION_16)
+    torch.onnx.register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _ONNX_OPSET_VERSION_11)