Unverified Commit f627b9d1 authored by Justin Chu's avatar Justin Chu Committed by GitHub
Browse files

[ONNX] Fix dtype for NonMaxSuppression (#7056)


Co-authored-by: default avatarNikita Shulga <nshulga@fb.com>
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent f9d1883e
...@@ -11,7 +11,6 @@ base_onnx_opset_version = _onnx_opset_version_11 ...@@ -11,7 +11,6 @@ base_onnx_opset_version = _onnx_opset_version_11
def _register_custom_op(): def _register_custom_op():
from torch.onnx.symbolic_helper import parse_args from torch.onnx.symbolic_helper import parse_args
from torch.onnx.symbolic_opset11 import select, squeeze, unsqueeze from torch.onnx.symbolic_opset11 import select, squeeze, unsqueeze
from torch.onnx.symbolic_opset9 import _cast_Long
@parse_args("v", "v", "f") @parse_args("v", "v", "f")
def symbolic_multi_label_nms(g, boxes, scores, iou_threshold): def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
...@@ -19,13 +18,18 @@ def _register_custom_op(): ...@@ -19,13 +18,18 @@ def _register_custom_op():
scores = unsqueeze(g, unsqueeze(g, scores, 0), 0) scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
max_output_per_class = g.op("Constant", value_t=torch.tensor([sys.maxsize], dtype=torch.long)) max_output_per_class = g.op("Constant", value_t=torch.tensor([sys.maxsize], dtype=torch.long))
iou_threshold = g.op("Constant", value_t=torch.tensor([iou_threshold], dtype=torch.float)) iou_threshold = g.op("Constant", value_t=torch.tensor([iou_threshold], dtype=torch.float))
nms_out = g.op("NonMaxSuppression", boxes, scores, max_output_per_class, iou_threshold) nms_out = g.op(
"NonMaxSuppression",
g.op("Cast", boxes, to_i=torch.onnx.TensorProtoDataType.FLOAT),
g.op("Cast", scores, to_i=torch.onnx.TensorProtoDataType.FLOAT),
max_output_per_class,
iou_threshold,
)
return squeeze(g, select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1) return squeeze(g, select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1)
def _process_batch_indices_for_roi_align(g, rois): def _process_batch_indices_for_roi_align(g, rois):
return _cast_Long( indices = squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1)
g, squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1), False return g.op("Cast", indices, to_i=torch.onnx.TensorProtoDataType.INT64)
)
def _process_rois_for_roi_align(g, rois): def _process_rois_for_roi_align(g, rois):
return select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long))) return select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment