Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
30bb1cea
Unverified
Commit
30bb1cea
authored
Feb 15, 2023
by
Justin Chu
Committed by
GitHub
Feb 15, 2023
Browse files
[ONNX] misc improvements (#7249)
Co-authored-by:
Nikita Shulga
<
nshulga@fb.com
>
parent
d805aeae
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
97 additions
and
87 deletions
+97
-87
test/test_onnx.py
test/test_onnx.py
+3
-3
torchvision/ops/_register_onnx_ops.py
torchvision/ops/_register_onnx_ops.py
+94
-84
No files found.
test/test_onnx.py
View file @
30bb1cea
...
...
@@ -34,7 +34,7 @@ class TestONNXExporter:
opset_version
:
Optional
[
int
]
=
None
,
):
if
opset_version
is
None
:
opset_version
=
_register_onnx_ops
.
base_onnx_opset_version
opset_version
=
_register_onnx_ops
.
BASE_ONNX_OPSET_VERSION
model
.
eval
()
...
...
@@ -139,7 +139,7 @@ class TestONNXExporter:
self
.
run_model
(
model
,
[(
x
,
single_roi
)])
def
test_roi_align_aligned
(
self
):
supported_onnx_version
=
_register_onnx_ops
.
_
onnx_opset_version
_16
supported_onnx_version
=
_register_onnx_ops
.
_
ONNX_OPSET_VERSION
_16
x
=
torch
.
rand
(
1
,
1
,
10
,
10
,
dtype
=
torch
.
float32
)
single_roi
=
torch
.
tensor
([[
0
,
1.5
,
1.5
,
3
,
3
]],
dtype
=
torch
.
float32
)
model
=
ops
.
RoIAlign
((
5
,
5
),
1
,
2
,
aligned
=
True
)
...
...
@@ -166,7 +166,7 @@ class TestONNXExporter:
self
.
run_model
(
model
,
[(
x
,
single_roi
)],
opset_version
=
supported_onnx_version
)
def
test_roi_align_malformed_boxes
(
self
):
supported_onnx_version
=
_register_onnx_ops
.
_
onnx_opset_version
_16
supported_onnx_version
=
_register_onnx_ops
.
_
ONNX_OPSET_VERSION
_16
x
=
torch
.
randn
(
1
,
1
,
10
,
10
,
dtype
=
torch
.
float32
)
single_roi
=
torch
.
tensor
([[
0
,
2
,
0.3
,
1.5
,
1.5
]],
dtype
=
torch
.
float32
)
model
=
ops
.
RoIAlign
((
5
,
5
),
1
,
1
,
aligned
=
True
)
...
...
torchvision/ops/_register_onnx_ops.py
View file @
30bb1cea
...
...
@@ -2,22 +2,22 @@ import sys
import
warnings
import
torch
from
torch.onnx
import
symbolic_opset11
as
opset11
from
torch.onnx.symbolic_helper
import
parse_args
_
onnx_opset_version
_11
=
11
_
onnx_opset_version
_16
=
16
base_onnx_opset_version
=
_onnx_opset_version
_11
_
ONNX_OPSET_VERSION
_11
=
11
_
ONNX_OPSET_VERSION
_16
=
16
BASE_ONNX_OPSET_VERSION
=
_ONNX_OPSET_VERSION
_11
def
_register_custom_op
():
from
torch.onnx.symbolic_helper
import
parse_args
from
torch.onnx.symbolic_opset11
import
select
,
squeeze
,
unsqueeze
@
parse_args
(
"v"
,
"v"
,
"f"
)
def
symbolic_multi_label_nms
(
g
,
boxes
,
scores
,
iou_threshold
):
boxes
=
unsqueeze
(
g
,
boxes
,
0
)
scores
=
unsqueeze
(
g
,
unsqueeze
(
g
,
scores
,
0
),
0
)
@
parse_args
(
"v"
,
"v"
,
"f"
)
def
symbolic_multi_label_nms
(
g
,
boxes
,
scores
,
iou_threshold
):
boxes
=
opset11
.
unsqueeze
(
g
,
boxes
,
0
)
scores
=
opset11
.
unsqueeze
(
g
,
opset11
.
unsqueeze
(
g
,
scores
,
0
),
0
)
max_output_per_class
=
g
.
op
(
"Constant"
,
value_t
=
torch
.
tensor
([
sys
.
maxsize
],
dtype
=
torch
.
long
))
iou_threshold
=
g
.
op
(
"Constant"
,
value_t
=
torch
.
tensor
([
iou_threshold
],
dtype
=
torch
.
float
))
# Cast boxes and scores to float32 in case they are float64 inputs
nms_out
=
g
.
op
(
"NonMaxSuppression"
,
g
.
op
(
"Cast"
,
boxes
,
to_i
=
torch
.
onnx
.
TensorProtoDataType
.
FLOAT
),
...
...
@@ -25,16 +25,23 @@ def _register_custom_op():
max_output_per_class
,
iou_threshold
,
)
return
squeeze
(
g
,
select
(
g
,
nms_out
,
1
,
g
.
op
(
"Constant"
,
value_t
=
torch
.
tensor
([
2
],
dtype
=
torch
.
long
))),
1
)
return
opset11
.
squeeze
(
g
,
opset11
.
select
(
g
,
nms_out
,
1
,
g
.
op
(
"Constant"
,
value_t
=
torch
.
tensor
([
2
],
dtype
=
torch
.
long
))),
1
)
def
_process_batch_indices_for_roi_align
(
g
,
rois
):
indices
=
squeeze
(
g
,
select
(
g
,
rois
,
1
,
g
.
op
(
"Constant"
,
value_t
=
torch
.
tensor
([
0
],
dtype
=
torch
.
long
))),
1
)
def
_process_batch_indices_for_roi_align
(
g
,
rois
):
indices
=
opset11
.
squeeze
(
g
,
opset11
.
select
(
g
,
rois
,
1
,
g
.
op
(
"Constant"
,
value_t
=
torch
.
tensor
([
0
],
dtype
=
torch
.
long
))),
1
)
return
g
.
op
(
"Cast"
,
indices
,
to_i
=
torch
.
onnx
.
TensorProtoDataType
.
INT64
)
def
_process_rois_for_roi_align
(
g
,
rois
):
return
select
(
g
,
rois
,
1
,
g
.
op
(
"Constant"
,
value_t
=
torch
.
tensor
([
1
,
2
,
3
,
4
],
dtype
=
torch
.
long
)))
def
_process_sampling_ratio_for_roi_align
(
g
,
sampling_ratio
:
int
):
def
_process_rois_for_roi_align
(
g
,
rois
):
return
opset11
.
select
(
g
,
rois
,
1
,
g
.
op
(
"Constant"
,
value_t
=
torch
.
tensor
([
1
,
2
,
3
,
4
],
dtype
=
torch
.
long
)))
def
_process_sampling_ratio_for_roi_align
(
g
,
sampling_ratio
:
int
):
if
sampling_ratio
<
0
:
warnings
.
warn
(
"ONNX export for RoIAlign with a non-zero sampling_ratio is not supported. "
...
...
@@ -43,8 +50,9 @@ def _register_custom_op():
sampling_ratio
=
0
return
sampling_ratio
@
parse_args
(
"v"
,
"v"
,
"f"
,
"i"
,
"i"
,
"i"
,
"i"
)
def
roi_align_opset11
(
g
,
input
,
rois
,
spatial_scale
,
pooled_height
,
pooled_width
,
sampling_ratio
,
aligned
):
@
parse_args
(
"v"
,
"v"
,
"f"
,
"i"
,
"i"
,
"i"
,
"i"
)
def
roi_align_opset11
(
g
,
input
,
rois
,
spatial_scale
,
pooled_height
,
pooled_width
,
sampling_ratio
,
aligned
):
batch_indices
=
_process_batch_indices_for_roi_align
(
g
,
rois
)
rois
=
_process_rois_for_roi_align
(
g
,
rois
)
if
aligned
:
...
...
@@ -64,8 +72,9 @@ def _register_custom_op():
sampling_ratio_i
=
sampling_ratio
,
)
@
parse_args
(
"v"
,
"v"
,
"f"
,
"i"
,
"i"
,
"i"
,
"i"
)
def
roi_align_opset16
(
g
,
input
,
rois
,
spatial_scale
,
pooled_height
,
pooled_width
,
sampling_ratio
,
aligned
):
@
parse_args
(
"v"
,
"v"
,
"f"
,
"i"
,
"i"
,
"i"
,
"i"
)
def
roi_align_opset16
(
g
,
input
,
rois
,
spatial_scale
,
pooled_height
,
pooled_width
,
sampling_ratio
,
aligned
):
batch_indices
=
_process_batch_indices_for_roi_align
(
g
,
rois
)
rois
=
_process_rois_for_roi_align
(
g
,
rois
)
coordinate_transformation_mode
=
"half_pixel"
if
aligned
else
"output_half_pixel"
...
...
@@ -82,16 +91,17 @@ def _register_custom_op():
sampling_ratio_i
=
sampling_ratio
,
)
@
parse_args
(
"v"
,
"v"
,
"f"
,
"i"
,
"i"
)
def
roi_pool
(
g
,
input
,
rois
,
spatial_scale
,
pooled_height
,
pooled_width
):
@
parse_args
(
"v"
,
"v"
,
"f"
,
"i"
,
"i"
)
def
roi_pool
(
g
,
input
,
rois
,
spatial_scale
,
pooled_height
,
pooled_width
):
roi_pool
=
g
.
op
(
"MaxRoiPool"
,
input
,
rois
,
pooled_shape_i
=
(
pooled_height
,
pooled_width
),
spatial_scale_f
=
spatial_scale
)
return
roi_pool
,
None
from
torch.onnx
import
register_custom_op_symbolic
register_custom_op_symbolic
(
"torchvision::nms"
,
symbolic_multi_label_nms
,
_onnx_opset_version_11
)
register_custom_op_symbolic
(
"torchvision::roi_align"
,
roi_align_opset11
,
_onnx_opset_version_11
)
register_custom_op_symbolic
(
"torchvision::roi_align"
,
roi_align_opset16
,
_onnx_opset_version_16
)
register_custom_op_symbolic
(
"torchvision::roi_pool"
,
roi_pool
,
_onnx_opset_version_11
)
def
_register_custom_op
():
torch
.
onnx
.
register_custom_op_symbolic
(
"torchvision::nms"
,
symbolic_multi_label_nms
,
_ONNX_OPSET_VERSION_11
)
torch
.
onnx
.
register_custom_op_symbolic
(
"torchvision::roi_align"
,
roi_align_opset11
,
_ONNX_OPSET_VERSION_11
)
torch
.
onnx
.
register_custom_op_symbolic
(
"torchvision::roi_align"
,
roi_align_opset16
,
_ONNX_OPSET_VERSION_16
)
torch
.
onnx
.
register_custom_op_symbolic
(
"torchvision::roi_pool"
,
roi_pool
,
_ONNX_OPSET_VERSION_11
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment