Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
fb2598b8
Unverified
Commit
fb2598b8
authored
Jun 11, 2021
by
Zhiqiang Wang
Committed by
GitHub
Jun 11, 2021
Browse files
Port test_onnx.py to pytest (#4047)
parent
552a4060
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
9 additions
and
9 deletions
+9
-9
.circleci/config.yml
.circleci/config.yml
+1
-0
.circleci/config.yml.in
.circleci/config.yml.in
+1
-0
test/test_onnx.py
test/test_onnx.py
+7
-9
No files found.
.circleci/config.yml
View file @
fb2598b8
...
@@ -257,6 +257,7 @@ jobs:
...
@@ -257,6 +257,7 @@ jobs:
pip install --user --progress-bar off --editable .
pip install --user --progress-bar off --editable .
pip install --user onnx
pip install --user onnx
pip install --user onnxruntime
pip install --user onnxruntime
pip install --user pytest
python test/test_onnx.py
python test/test_onnx.py
binary_linux_wheel
:
binary_linux_wheel
:
...
...
.circleci/config.yml.in
View file @
fb2598b8
...
@@ -257,6 +257,7 @@ jobs:
...
@@ -257,6 +257,7 @@ jobs:
pip install --user --progress-bar off --editable .
pip install --user --progress-bar off --editable .
pip install --user onnx
pip install --user onnx
pip install --user onnxruntime
pip install --user onnxruntime
pip install --user pytest
python test/test_onnx.py
python test/test_onnx.py
binary_linux_wheel:
binary_linux_wheel:
...
...
test/test_onnx.py
View file @
fb2598b8
...
@@ -15,21 +15,19 @@ from torchvision import models
...
@@ -15,21 +15,19 @@ from torchvision import models
from
torchvision.models.detection.image_list
import
ImageList
from
torchvision.models.detection.image_list
import
ImageList
from
torchvision.models.detection.transform
import
GeneralizedRCNNTransform
from
torchvision.models.detection.transform
import
GeneralizedRCNNTransform
from
torchvision.models.detection.rpn
import
AnchorGenerator
,
RPNHead
,
RegionProposalNetwork
from
torchvision.models.detection.rpn
import
AnchorGenerator
,
RPNHead
,
RegionProposalNetwork
from
torchvision.models.detection.backbone_utils
import
resnet_fpn_backbone
from
torchvision.models.detection.roi_heads
import
RoIHeads
from
torchvision.models.detection.roi_heads
import
RoIHeads
from
torchvision.models.detection.faster_rcnn
import
FastRCNNPredictor
,
TwoMLPHead
from
torchvision.models.detection.faster_rcnn
import
FastRCNNPredictor
,
TwoMLPHead
from
torchvision.models.detection.mask_rcnn
import
MaskRCNNHeads
,
MaskRCNNPredictor
from
collections
import
OrderedDict
from
collections
import
OrderedDict
import
unit
test
import
py
test
from
torchvision.ops._register_onnx_ops
import
_onnx_opset_version
from
torchvision.ops._register_onnx_ops
import
_onnx_opset_version
@
unit
test
.
skip
I
f
(
onnxruntime
is
None
,
'ONNX Runtime unavailable'
)
@
py
test
.
mark
.
skip
i
f
(
onnxruntime
is
None
,
reason
=
'ONNX Runtime unavailable'
)
class
ONNXExporter
Tester
(
unittest
.
TestCase
)
:
class
Test
ONNXExporter
:
@
classmethod
@
classmethod
def
set
UpC
lass
(
cls
):
def
set
up_c
lass
(
cls
):
torch
.
manual_seed
(
123
)
torch
.
manual_seed
(
123
)
def
run_model
(
self
,
model
,
inputs_list
,
tolerate_small_mismatch
=
False
,
do_constant_folding
=
True
,
dynamic_axes
=
None
,
def
run_model
(
self
,
model
,
inputs_list
,
tolerate_small_mismatch
=
False
,
do_constant_folding
=
True
,
dynamic_axes
=
None
,
...
@@ -80,7 +78,7 @@ class ONNXExporterTester(unittest.TestCase):
...
@@ -80,7 +78,7 @@ class ONNXExporterTester(unittest.TestCase):
torch
.
testing
.
assert_allclose
(
outputs
[
i
],
ort_outs
[
i
],
rtol
=
1e-03
,
atol
=
1e-05
)
torch
.
testing
.
assert_allclose
(
outputs
[
i
],
ort_outs
[
i
],
rtol
=
1e-03
,
atol
=
1e-05
)
except
AssertionError
as
error
:
except
AssertionError
as
error
:
if
tolerate_small_mismatch
:
if
tolerate_small_mismatch
:
self
.
assert
In
(
"(0.00%)"
,
str
(
error
),
str
(
error
)
)
assert
"(0.00%)"
in
str
(
error
),
str
(
error
)
else
:
else
:
raise
raise
...
@@ -161,7 +159,7 @@ class ONNXExporterTester(unittest.TestCase):
...
@@ -161,7 +159,7 @@ class ONNXExporterTester(unittest.TestCase):
model
=
ops
.
RoIAlign
((
2
,
2
),
2.5
,
-
1
,
aligned
=
True
)
model
=
ops
.
RoIAlign
((
2
,
2
),
2.5
,
-
1
,
aligned
=
True
)
self
.
run_model
(
model
,
[(
x
,
single_roi
)])
self
.
run_model
(
model
,
[(
x
,
single_roi
)])
@
unit
test
.
skip
#
Issue in exporting ROIAlign with aligned = True for malformed boxes
@
py
test
.
mark
.
skip
(
reason
=
"
Issue in exporting ROIAlign with aligned = True for malformed boxes
"
)
def
test_roi_align_malformed_boxes
(
self
):
def
test_roi_align_malformed_boxes
(
self
):
x
=
torch
.
randn
(
1
,
1
,
10
,
10
,
dtype
=
torch
.
float32
)
x
=
torch
.
randn
(
1
,
1
,
10
,
10
,
dtype
=
torch
.
float32
)
single_roi
=
torch
.
tensor
([[
0
,
2
,
0.3
,
1.5
,
1.5
]],
dtype
=
torch
.
float32
)
single_roi
=
torch
.
tensor
([[
0
,
2
,
0.3
,
1.5
,
1.5
]],
dtype
=
torch
.
float32
)
...
@@ -527,4 +525,4 @@ class ONNXExporterTester(unittest.TestCase):
...
@@ -527,4 +525,4 @@ class ONNXExporterTester(unittest.TestCase):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unit
test
.
main
()
py
test
.
main
(
[
__file__
]
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment