Unverified Commit 30bb1cea authored by Justin Chu's avatar Justin Chu Committed by GitHub
Browse files

[ONNX] misc improvements (#7249)


Co-authored-by: default avatarNikita Shulga <nshulga@fb.com>
parent d805aeae
......@@ -34,7 +34,7 @@ class TestONNXExporter:
opset_version: Optional[int] = None,
):
if opset_version is None:
opset_version = _register_onnx_ops.base_onnx_opset_version
opset_version = _register_onnx_ops.BASE_ONNX_OPSET_VERSION
model.eval()
......@@ -139,7 +139,7 @@ class TestONNXExporter:
self.run_model(model, [(x, single_roi)])
def test_roi_align_aligned(self):
supported_onnx_version = _register_onnx_ops._onnx_opset_version_16
supported_onnx_version = _register_onnx_ops._ONNX_OPSET_VERSION_16
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1, 2, aligned=True)
......@@ -166,7 +166,7 @@ class TestONNXExporter:
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)
def test_roi_align_malformed_boxes(self):
supported_onnx_version = _register_onnx_ops._onnx_opset_version_16
supported_onnx_version = _register_onnx_ops._ONNX_OPSET_VERSION_16
x = torch.randn(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 2, 0.3, 1.5, 1.5]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1, 1, aligned=True)
......
......@@ -2,22 +2,22 @@ import sys
import warnings
import torch
from torch.onnx import symbolic_opset11 as opset11
from torch.onnx.symbolic_helper import parse_args
_onnx_opset_version_11 = 11
_onnx_opset_version_16 = 16
base_onnx_opset_version = _onnx_opset_version_11
_ONNX_OPSET_VERSION_11 = 11
_ONNX_OPSET_VERSION_16 = 16
BASE_ONNX_OPSET_VERSION = _ONNX_OPSET_VERSION_11
def _register_custom_op():
from torch.onnx.symbolic_helper import parse_args
from torch.onnx.symbolic_opset11 import select, squeeze, unsqueeze
@parse_args("v", "v", "f")
def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
boxes = unsqueeze(g, boxes, 0)
scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
@parse_args("v", "v", "f")
def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
boxes = opset11.unsqueeze(g, boxes, 0)
scores = opset11.unsqueeze(g, opset11.unsqueeze(g, scores, 0), 0)
max_output_per_class = g.op("Constant", value_t=torch.tensor([sys.maxsize], dtype=torch.long))
iou_threshold = g.op("Constant", value_t=torch.tensor([iou_threshold], dtype=torch.float))
# Cast boxes and scores to float32 in case they are float64 inputs
nms_out = g.op(
"NonMaxSuppression",
g.op("Cast", boxes, to_i=torch.onnx.TensorProtoDataType.FLOAT),
......@@ -25,16 +25,23 @@ def _register_custom_op():
max_output_per_class,
iou_threshold,
)
return squeeze(g, select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1)
return opset11.squeeze(
g, opset11.select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1
)
def _process_batch_indices_for_roi_align(g, rois):
indices = squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1)
def _process_batch_indices_for_roi_align(g, rois):
indices = opset11.squeeze(
g, opset11.select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1
)
return g.op("Cast", indices, to_i=torch.onnx.TensorProtoDataType.INT64)
def _process_rois_for_roi_align(g, rois):
return select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
def _process_sampling_ratio_for_roi_align(g, sampling_ratio: int):
def _process_rois_for_roi_align(g, rois):
return opset11.select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
def _process_sampling_ratio_for_roi_align(g, sampling_ratio: int):
if sampling_ratio < 0:
warnings.warn(
"ONNX export for RoIAlign with a non-zero sampling_ratio is not supported. "
......@@ -43,8 +50,9 @@ def _register_custom_op():
sampling_ratio = 0
return sampling_ratio
@parse_args("v", "v", "f", "i", "i", "i", "i")
def roi_align_opset11(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
@parse_args("v", "v", "f", "i", "i", "i", "i")
def roi_align_opset11(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
batch_indices = _process_batch_indices_for_roi_align(g, rois)
rois = _process_rois_for_roi_align(g, rois)
if aligned:
......@@ -64,8 +72,9 @@ def _register_custom_op():
sampling_ratio_i=sampling_ratio,
)
@parse_args("v", "v", "f", "i", "i", "i", "i")
def roi_align_opset16(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
@parse_args("v", "v", "f", "i", "i", "i", "i")
def roi_align_opset16(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
batch_indices = _process_batch_indices_for_roi_align(g, rois)
rois = _process_rois_for_roi_align(g, rois)
coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel"
......@@ -82,16 +91,17 @@ def _register_custom_op():
sampling_ratio_i=sampling_ratio,
)
@parse_args("v", "v", "f", "i", "i")
def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width):
@parse_args("v", "v", "f", "i", "i")
def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width):
roi_pool = g.op(
"MaxRoiPool", input, rois, pooled_shape_i=(pooled_height, pooled_width), spatial_scale_f=spatial_scale
)
return roi_pool, None
from torch.onnx import register_custom_op_symbolic
register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _onnx_opset_version_11)
register_custom_op_symbolic("torchvision::roi_align", roi_align_opset11, _onnx_opset_version_11)
register_custom_op_symbolic("torchvision::roi_align", roi_align_opset16, _onnx_opset_version_16)
register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _onnx_opset_version_11)
def _register_custom_op():
torch.onnx.register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _ONNX_OPSET_VERSION_11)
torch.onnx.register_custom_op_symbolic("torchvision::roi_align", roi_align_opset11, _ONNX_OPSET_VERSION_11)
torch.onnx.register_custom_op_symbolic("torchvision::roi_align", roi_align_opset16, _ONNX_OPSET_VERSION_16)
torch.onnx.register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _ONNX_OPSET_VERSION_11)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment