Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
279fca56
Unverified
Commit
279fca56
authored
Aug 27, 2020
by
Negin Raoof
Committed by
GitHub
Aug 27, 2020
Browse files
[ONNX] Export ROIAlign with aligned=True (#2613)
* Add support for export ROIAlign * Fix for feedback * flake8
parent
6f028212
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
2 deletions
+35
-2
test/test_onnx.py
test/test_onnx.py
+29
-0
torchvision/ops/_register_onnx_ops.py
torchvision/ops/_register_onnx_ops.py
+6
-2
No files found.
test/test_onnx.py
View file @
279fca56
...
@@ -66,6 +66,7 @@ class ONNXExporterTester(unittest.TestCase):
...
@@ -66,6 +66,7 @@ class ONNXExporterTester(unittest.TestCase):
# compute onnxruntime output prediction
# compute onnxruntime output prediction
ort_inputs
=
dict
((
ort_session
.
get_inputs
()[
i
].
name
,
inpt
)
for
i
,
inpt
in
enumerate
(
inputs
))
ort_inputs
=
dict
((
ort_session
.
get_inputs
()[
i
].
name
,
inpt
)
for
i
,
inpt
in
enumerate
(
inputs
))
ort_outs
=
ort_session
.
run
(
None
,
ort_inputs
)
ort_outs
=
ort_session
.
run
(
None
,
ort_inputs
)
for
i
in
range
(
0
,
len
(
outputs
)):
for
i
in
range
(
0
,
len
(
outputs
)):
try
:
try
:
torch
.
testing
.
assert_allclose
(
outputs
[
i
],
ort_outs
[
i
],
rtol
=
1e-03
,
atol
=
1e-05
)
torch
.
testing
.
assert_allclose
(
outputs
[
i
],
ort_outs
[
i
],
rtol
=
1e-03
,
atol
=
1e-05
)
...
@@ -121,6 +122,34 @@ class ONNXExporterTester(unittest.TestCase):
...
@@ -121,6 +122,34 @@ class ONNXExporterTester(unittest.TestCase):
model
=
ops
.
RoIAlign
((
5
,
5
),
1
,
2
)
model
=
ops
.
RoIAlign
((
5
,
5
),
1
,
2
)
self
.
run_model
(
model
,
[(
x
,
single_roi
)])
self
.
run_model
(
model
,
[(
x
,
single_roi
)])
def
test_roi_align_aligned
(
self
):
x
=
torch
.
rand
(
1
,
1
,
10
,
10
,
dtype
=
torch
.
float32
)
single_roi
=
torch
.
tensor
([[
0
,
1.5
,
1.5
,
3
,
3
]],
dtype
=
torch
.
float32
)
model
=
ops
.
RoIAlign
((
5
,
5
),
1
,
2
,
aligned
=
True
)
self
.
run_model
(
model
,
[(
x
,
single_roi
)])
x
=
torch
.
rand
(
1
,
1
,
10
,
10
,
dtype
=
torch
.
float32
)
single_roi
=
torch
.
tensor
([[
0
,
0.2
,
0.3
,
4.5
,
3.5
]],
dtype
=
torch
.
float32
)
model
=
ops
.
RoIAlign
((
5
,
5
),
0.5
,
3
,
aligned
=
True
)
self
.
run_model
(
model
,
[(
x
,
single_roi
)])
x
=
torch
.
rand
(
1
,
1
,
10
,
10
,
dtype
=
torch
.
float32
)
single_roi
=
torch
.
tensor
([[
0
,
0.2
,
0.3
,
4.5
,
3.5
]],
dtype
=
torch
.
float32
)
model
=
ops
.
RoIAlign
((
5
,
5
),
1.8
,
2
,
aligned
=
True
)
self
.
run_model
(
model
,
[(
x
,
single_roi
)])
x
=
torch
.
rand
(
1
,
1
,
10
,
10
,
dtype
=
torch
.
float32
)
single_roi
=
torch
.
tensor
([[
0
,
0.2
,
0.3
,
4.5
,
3.5
]],
dtype
=
torch
.
float32
)
model
=
ops
.
RoIAlign
((
2
,
2
),
2.5
,
0
,
aligned
=
True
)
self
.
run_model
(
model
,
[(
x
,
single_roi
)])
@
unittest
.
skip
# Issue in exporting ROIAlign with aligned = True for malformed boxes
def
test_roi_align_malformed_boxes
(
self
):
x
=
torch
.
randn
(
1
,
1
,
10
,
10
,
dtype
=
torch
.
float32
)
single_roi
=
torch
.
tensor
([[
0
,
2
,
0.3
,
1.5
,
1.5
]],
dtype
=
torch
.
float32
)
model
=
ops
.
RoIAlign
((
5
,
5
),
1
,
1
,
aligned
=
True
)
self
.
run_model
(
model
,
[(
x
,
single_roi
)])
def
test_roi_pool
(
self
):
def
test_roi_pool
(
self
):
x
=
torch
.
rand
(
1
,
1
,
10
,
10
,
dtype
=
torch
.
float32
)
x
=
torch
.
rand
(
1
,
1
,
10
,
10
,
dtype
=
torch
.
float32
)
rois
=
torch
.
tensor
([[
0
,
0
,
0
,
4
,
4
]],
dtype
=
torch
.
float32
)
rois
=
torch
.
tensor
([[
0
,
0
,
0
,
4
,
4
]],
dtype
=
torch
.
float32
)
...
...
torchvision/ops/_register_onnx_ops.py
View file @
279fca56
import
sys
import
sys
import
torch
import
torch
import
warnings
_onnx_opset_version
=
11
_onnx_opset_version
=
11
...
@@ -20,11 +21,14 @@ def _register_custom_op():
...
@@ -20,11 +21,14 @@ def _register_custom_op():
@
parse_args
(
'v'
,
'v'
,
'f'
,
'i'
,
'i'
,
'i'
,
'i'
)
@
parse_args
(
'v'
,
'v'
,
'f'
,
'i'
,
'i'
,
'i'
,
'i'
)
def
roi_align
(
g
,
input
,
rois
,
spatial_scale
,
pooled_height
,
pooled_width
,
sampling_ratio
,
aligned
):
def
roi_align
(
g
,
input
,
rois
,
spatial_scale
,
pooled_height
,
pooled_width
,
sampling_ratio
,
aligned
):
if
(
aligned
):
raise
RuntimeError
(
'Unsupported: ONNX export of roi_align with aligned'
)
batch_indices
=
_cast_Long
(
g
,
squeeze
(
g
,
select
(
g
,
rois
,
1
,
g
.
op
(
'Constant'
,
batch_indices
=
_cast_Long
(
g
,
squeeze
(
g
,
select
(
g
,
rois
,
1
,
g
.
op
(
'Constant'
,
value_t
=
torch
.
tensor
([
0
],
dtype
=
torch
.
long
))),
1
),
False
)
value_t
=
torch
.
tensor
([
0
],
dtype
=
torch
.
long
))),
1
),
False
)
rois
=
select
(
g
,
rois
,
1
,
g
.
op
(
'Constant'
,
value_t
=
torch
.
tensor
([
1
,
2
,
3
,
4
],
dtype
=
torch
.
long
)))
rois
=
select
(
g
,
rois
,
1
,
g
.
op
(
'Constant'
,
value_t
=
torch
.
tensor
([
1
,
2
,
3
,
4
],
dtype
=
torch
.
long
)))
if
aligned
:
warnings
.
warn
(
"ONNX export of ROIAlign with aligned=True does not match PyTorch when using malformed boxes,"
" ONNX forces ROIs to be 1x1 or larger."
)
scale
=
torch
.
tensor
(
0.5
/
spatial_scale
).
to
(
dtype
=
torch
.
float
)
rois
=
g
.
op
(
"Sub"
,
rois
,
scale
)
return
g
.
op
(
'RoiAlign'
,
input
,
rois
,
batch_indices
,
spatial_scale_f
=
spatial_scale
,
return
g
.
op
(
'RoiAlign'
,
input
,
rois
,
batch_indices
,
spatial_scale_f
=
spatial_scale
,
output_height_i
=
pooled_height
,
output_width_i
=
pooled_width
,
sampling_ratio_i
=
sampling_ratio
)
output_height_i
=
pooled_height
,
output_width_i
=
pooled_width
,
sampling_ratio_i
=
sampling_ratio
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment