Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
be6f398c
Commit
be6f398c
authored
Nov 06, 2019
by
Lara Haidar
Committed by
Francisco Massa
Nov 06, 2019
Browse files
Enable ONNX Test for FasterRcnn (#1555)
* enable faster rcnn test * flake8 * smaller image size * set min/max
parent
af225a8a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
18 additions
and
31 deletions
+18
-31
test/test_onnx.py
test/test_onnx.py
+11
-11
torchvision/models/detection/rpn.py
torchvision/models/detection/rpn.py
+1
-7
torchvision/models/detection/transform.py
torchvision/models/detection/transform.py
+1
-10
torchvision/ops/_register_onnx_ops.py
torchvision/ops/_register_onnx_ops.py
+5
-3
No files found.
test/test_onnx.py
View file @
be6f398c
...
@@ -19,6 +19,7 @@ except ImportError:
...
@@ -19,6 +19,7 @@ except ImportError:
onnxruntime
=
None
onnxruntime
=
None
import
unittest
import
unittest
from
torchvision.ops._register_onnx_ops
import
_onnx_opset_version
@
unittest
.
skipIf
(
onnxruntime
is
None
,
'ONNX Runtime unavailable'
)
@
unittest
.
skipIf
(
onnxruntime
is
None
,
'ONNX Runtime unavailable'
)
...
@@ -32,7 +33,8 @@ class ONNXExporterTester(unittest.TestCase):
...
@@ -32,7 +33,8 @@ class ONNXExporterTester(unittest.TestCase):
onnx_io
=
io
.
BytesIO
()
onnx_io
=
io
.
BytesIO
()
# export to onnx with the first input
# export to onnx with the first input
torch
.
onnx
.
export
(
model
,
inputs_list
[
0
],
onnx_io
,
do_constant_folding
=
True
,
opset_version
=
10
)
torch
.
onnx
.
export
(
model
,
inputs_list
[
0
],
onnx_io
,
do_constant_folding
=
True
,
opset_version
=
_onnx_opset_version
)
# validate the exported model with onnx runtime
# validate the exported model with onnx runtime
for
test_inputs
in
inputs_list
:
for
test_inputs
in
inputs_list
:
...
@@ -97,7 +99,6 @@ class ONNXExporterTester(unittest.TestCase):
...
@@ -97,7 +99,6 @@ class ONNXExporterTester(unittest.TestCase):
model
=
ops
.
RoIPool
((
pool_h
,
pool_w
),
2
)
model
=
ops
.
RoIPool
((
pool_h
,
pool_w
),
2
)
self
.
run_model
(
model
,
[(
x
,
rois
)])
self
.
run_model
(
model
,
[(
x
,
rois
)])
@
unittest
.
skip
(
"Disable test until Resize opset 11 is implemented in ONNX Runtime"
)
def
test_transform_images
(
self
):
def
test_transform_images
(
self
):
class
TransformModule
(
torch
.
nn
.
Module
):
class
TransformModule
(
torch
.
nn
.
Module
):
...
@@ -108,13 +109,13 @@ class ONNXExporterTester(unittest.TestCase):
...
@@ -108,13 +109,13 @@ class ONNXExporterTester(unittest.TestCase):
def
forward
(
self_module
,
images
):
def
forward
(
self_module
,
images
):
return
self_module
.
transform
(
images
)[
0
].
tensors
return
self_module
.
transform
(
images
)[
0
].
tensors
input
=
[
torch
.
rand
(
3
,
8
00
,
128
0
),
torch
.
rand
(
3
,
8
00
,
8
00
)]
input
=
[
torch
.
rand
(
3
,
1
00
,
20
0
),
torch
.
rand
(
3
,
2
00
,
2
00
)]
input_test
=
[
torch
.
rand
(
3
,
8
00
,
128
0
),
torch
.
rand
(
3
,
8
00
,
8
00
)]
input_test
=
[
torch
.
rand
(
3
,
1
00
,
20
0
),
torch
.
rand
(
3
,
2
00
,
2
00
)]
self
.
run_model
(
TransformModule
(),
[
input
,
input_test
])
self
.
run_model
(
TransformModule
(),
[
input
,
input_test
])
def
_init_test_generalized_rcnn_transform
(
self
):
def
_init_test_generalized_rcnn_transform
(
self
):
min_size
=
8
00
min_size
=
1
00
max_size
=
1333
max_size
=
200
image_mean
=
[
0.485
,
0.456
,
0.406
]
image_mean
=
[
0.485
,
0.456
,
0.406
]
image_std
=
[
0.229
,
0.224
,
0.225
]
image_std
=
[
0.229
,
0.224
,
0.225
]
transform
=
GeneralizedRCNNTransform
(
min_size
,
max_size
,
image_mean
,
image_std
)
transform
=
GeneralizedRCNNTransform
(
min_size
,
max_size
,
image_mean
,
image_std
)
...
@@ -234,7 +235,6 @@ class ONNXExporterTester(unittest.TestCase):
...
@@ -234,7 +235,6 @@ class ONNXExporterTester(unittest.TestCase):
self
.
run_model
(
TransformModule
(),
[(
i
,
[
boxes
],),
(
i1
,
[
boxes1
],)])
self
.
run_model
(
TransformModule
(),
[(
i
,
[
boxes
],),
(
i1
,
[
boxes1
],)])
@
unittest
.
skipIf
(
torch
.
__version__
<
"1.4."
,
"Disable test if torch version is less than 1.4"
)
def
test_roi_heads
(
self
):
def
test_roi_heads
(
self
):
class
RoiHeadsModule
(
torch
.
nn
.
Module
):
class
RoiHeadsModule
(
torch
.
nn
.
Module
):
def
__init__
(
self_module
,
images
):
def
__init__
(
self_module
,
images
):
...
@@ -271,7 +271,7 @@ class ONNXExporterTester(unittest.TestCase):
...
@@ -271,7 +271,7 @@ class ONNXExporterTester(unittest.TestCase):
data
=
requests
.
get
(
url
)
data
=
requests
.
get
(
url
)
image
=
Image
.
open
(
BytesIO
(
data
.
content
)).
convert
(
"RGB"
)
image
=
Image
.
open
(
BytesIO
(
data
.
content
)).
convert
(
"RGB"
)
image
=
image
.
resize
((
8
00
,
128
0
),
Image
.
BILINEAR
)
image
=
image
.
resize
((
3
00
,
20
0
),
Image
.
BILINEAR
)
to_tensor
=
transforms
.
ToTensor
()
to_tensor
=
transforms
.
ToTensor
()
return
to_tensor
(
image
)
return
to_tensor
(
image
)
...
@@ -285,12 +285,12 @@ class ONNXExporterTester(unittest.TestCase):
...
@@ -285,12 +285,12 @@ class ONNXExporterTester(unittest.TestCase):
test_images
=
[
image2
]
test_images
=
[
image2
]
return
images
,
test_images
return
images
,
test_images
@
unittest
.
skip
(
"Disable test until Resize opset 11 is implemented in ONNX Runtime"
)
@
unittest
.
skipIf
(
torch
.
__version__
<
"1.4."
,
"Disable test if torch version is less than 1.4"
)
def
test_faster_rcnn
(
self
):
def
test_faster_rcnn
(
self
):
images
,
test_images
=
self
.
get_test_images
()
images
,
test_images
=
self
.
get_test_images
()
model
=
models
.
detection
.
faster_rcnn
.
fasterrcnn_resnet50_fpn
(
pretrained
=
True
)
model
=
models
.
detection
.
faster_rcnn
.
fasterrcnn_resnet50_fpn
(
pretrained
=
True
,
min_size
=
200
,
max_size
=
300
)
model
.
eval
()
model
.
eval
()
model
(
images
)
model
(
images
)
self
.
run_model
(
model
,
[(
images
,),
(
test_images
,)])
self
.
run_model
(
model
,
[(
images
,),
(
test_images
,)])
...
...
torchvision/models/detection/rpn.py
View file @
be6f398c
...
@@ -110,13 +110,7 @@ class AnchorGenerator(nn.Module):
...
@@ -110,13 +110,7 @@ class AnchorGenerator(nn.Module):
shifts_y
=
torch
.
arange
(
shifts_y
=
torch
.
arange
(
0
,
grid_height
,
dtype
=
torch
.
float32
,
device
=
device
0
,
grid_height
,
dtype
=
torch
.
float32
,
device
=
device
)
*
stride_height
)
*
stride_height
# TODO: remove tracing pass when exporting torch.meshgrid()
shift_y
,
shift_x
=
torch
.
meshgrid
(
shifts_y
,
shifts_x
)
# is suported in ONNX
if
torchvision
.
_is_tracing
():
shift_y
=
shifts_y
.
view
(
-
1
,
1
).
expand
(
grid_height
,
grid_width
)
shift_x
=
shifts_x
.
view
(
1
,
-
1
).
expand
(
grid_height
,
grid_width
)
else
:
shift_y
,
shift_x
=
torch
.
meshgrid
(
shifts_y
,
shifts_x
)
shift_x
=
shift_x
.
reshape
(
-
1
)
shift_x
=
shift_x
.
reshape
(
-
1
)
shift_y
=
shift_y
.
reshape
(
-
1
)
shift_y
=
shift_y
.
reshape
(
-
1
)
shifts
=
torch
.
stack
((
shift_x
,
shift_y
,
shift_x
,
shift_y
),
dim
=
1
)
shifts
=
torch
.
stack
((
shift_x
,
shift_y
,
shift_x
,
shift_y
),
dim
=
1
)
...
...
torchvision/models/detection/transform.py
View file @
be6f398c
...
@@ -89,15 +89,6 @@ class GeneralizedRCNNTransform(nn.Module):
...
@@ -89,15 +89,6 @@ class GeneralizedRCNNTransform(nn.Module):
target
[
"keypoints"
]
=
keypoints
target
[
"keypoints"
]
=
keypoints
return
image
,
target
return
image
,
target
# _onnx_dynamic_img_pad() creates a dynamic padding
# for an image supported in ONNx tracing.
# it is used to process the images in _onnx_batch_images().
def
_onnx_dynamic_img_pad
(
self
,
img
,
padding
):
concat_0
=
torch
.
cat
((
img
,
torch
.
zeros
(
padding
[
0
],
img
.
shape
[
1
],
img
.
shape
[
2
])),
0
)
concat_1
=
torch
.
cat
((
concat_0
,
torch
.
zeros
(
concat_0
.
shape
[
0
],
padding
[
1
],
concat_0
.
shape
[
2
])),
1
)
padded_img
=
torch
.
cat
((
concat_1
,
torch
.
zeros
(
concat_1
.
shape
[
0
],
concat_1
.
shape
[
1
],
padding
[
2
])),
2
)
return
padded_img
# _onnx_batch_images() is an implementation of
# _onnx_batch_images() is an implementation of
# batch_images() that is supported by ONNX tracing.
# batch_images() that is supported by ONNX tracing.
def
_onnx_batch_images
(
self
,
images
,
size_divisible
=
32
):
def
_onnx_batch_images
(
self
,
images
,
size_divisible
=
32
):
...
@@ -116,7 +107,7 @@ class GeneralizedRCNNTransform(nn.Module):
...
@@ -116,7 +107,7 @@ class GeneralizedRCNNTransform(nn.Module):
padded_imgs
=
[]
padded_imgs
=
[]
for
img
in
images
:
for
img
in
images
:
padding
=
[(
s1
-
s2
)
for
s1
,
s2
in
zip
(
max_size
,
tuple
(
img
.
shape
))]
padding
=
[(
s1
-
s2
)
for
s1
,
s2
in
zip
(
max_size
,
tuple
(
img
.
shape
))]
padded_img
=
self
.
_onnx_dynamic_img_pad
(
img
,
padding
)
padded_img
=
torch
.
nn
.
functional
.
pad
(
img
,
(
0
,
padding
[
2
],
0
,
padding
[
1
],
0
,
padding
[
0
])
)
padded_imgs
.
append
(
padded_img
)
padded_imgs
.
append
(
padded_img
)
return
torch
.
stack
(
padded_imgs
)
return
torch
.
stack
(
padded_imgs
)
...
...
torchvision/ops/_register_onnx_ops.py
View file @
be6f398c
import
sys
import
sys
import
torch
import
torch
_onnx_opset_version
=
11
def
_register_custom_op
():
def
_register_custom_op
():
from
torch.onnx.symbolic_helper
import
parse_args
,
scalar_type_to_onnx
from
torch.onnx.symbolic_helper
import
parse_args
,
scalar_type_to_onnx
...
@@ -30,6 +32,6 @@ def _register_custom_op():
...
@@ -30,6 +32,6 @@ def _register_custom_op():
return
roi_pool
,
None
return
roi_pool
,
None
from
torch.onnx
import
register_custom_op_symbolic
from
torch.onnx
import
register_custom_op_symbolic
register_custom_op_symbolic
(
'torchvision::nms'
,
symbolic_multi_label_nms
,
10
)
register_custom_op_symbolic
(
'torchvision::nms'
,
symbolic_multi_label_nms
,
_onnx_opset_version
)
register_custom_op_symbolic
(
'torchvision::roi_align'
,
roi_align
,
10
)
register_custom_op_symbolic
(
'torchvision::roi_align'
,
roi_align
,
_onnx_opset_version
)
register_custom_op_symbolic
(
'torchvision::roi_pool'
,
roi_pool
,
10
)
register_custom_op_symbolic
(
'torchvision::roi_pool'
,
roi_pool
,
_onnx_opset_version
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment