Unverified Commit f627b9d1 authored by Justin Chu's avatar Justin Chu Committed by GitHub
Browse files

[ONNX] Fix dtype for NonMaxSuppression (#7056)


Co-authored-by: default avatarNikita Shulga <nshulga@fb.com>
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent f9d1883e
......@@ -11,7 +11,6 @@ base_onnx_opset_version = _onnx_opset_version_11
def _register_custom_op():
from torch.onnx.symbolic_helper import parse_args
from torch.onnx.symbolic_opset11 import select, squeeze, unsqueeze
from torch.onnx.symbolic_opset9 import _cast_Long
@parse_args("v", "v", "f")
def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
......@@ -19,13 +18,18 @@ def _register_custom_op():
scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
max_output_per_class = g.op("Constant", value_t=torch.tensor([sys.maxsize], dtype=torch.long))
iou_threshold = g.op("Constant", value_t=torch.tensor([iou_threshold], dtype=torch.float))
nms_out = g.op("NonMaxSuppression", boxes, scores, max_output_per_class, iou_threshold)
nms_out = g.op(
"NonMaxSuppression",
g.op("Cast", boxes, to_i=torch.onnx.TensorProtoDataType.FLOAT),
g.op("Cast", scores, to_i=torch.onnx.TensorProtoDataType.FLOAT),
max_output_per_class,
iou_threshold,
)
return squeeze(g, select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1)
def _process_batch_indices_for_roi_align(g, rois):
return _cast_Long(
g, squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1), False
)
indices = squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1)
return g.op("Cast", indices, to_i=torch.onnx.TensorProtoDataType.INT64)
def _process_rois_for_roi_align(g, rois):
return select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment