Unverified Commit fb2598b8 authored by Zhiqiang Wang's avatar Zhiqiang Wang Committed by GitHub
Browse files

Port test_onnx.py to pytest (#4047)

parent 552a4060
......@@ -257,6 +257,7 @@ jobs:
pip install --user --progress-bar off --editable .
pip install --user onnx
pip install --user onnxruntime
pip install --user pytest
python test/test_onnx.py
binary_linux_wheel:
......
......@@ -257,6 +257,7 @@ jobs:
pip install --user --progress-bar off --editable .
pip install --user onnx
pip install --user onnxruntime
pip install --user pytest
python test/test_onnx.py
binary_linux_wheel:
......
......@@ -15,21 +15,19 @@ from torchvision import models
from torchvision.models.detection.image_list import ImageList
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from torchvision.models.detection.mask_rcnn import MaskRCNNHeads, MaskRCNNPredictor
from collections import OrderedDict
import unittest
import pytest
from torchvision.ops._register_onnx_ops import _onnx_opset_version
@unittest.skipIf(onnxruntime is None, 'ONNX Runtime unavailable')
class ONNXExporterTester(unittest.TestCase):
@pytest.mark.skipif(onnxruntime is None, reason='ONNX Runtime unavailable')
class TestONNXExporter:
@classmethod
def setUpClass(cls):
def setup_class(cls):
torch.manual_seed(123)
def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True, dynamic_axes=None,
......@@ -80,7 +78,7 @@ class ONNXExporterTester(unittest.TestCase):
torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-03, atol=1e-05)
except AssertionError as error:
if tolerate_small_mismatch:
self.assertIn("(0.00%)", str(error), str(error))
assert "(0.00%)" in str(error), str(error)
else:
raise
......@@ -161,7 +159,7 @@ class ONNXExporterTester(unittest.TestCase):
model = ops.RoIAlign((2, 2), 2.5, -1, aligned=True)
self.run_model(model, [(x, single_roi)])
@unittest.skip # Issue in exporting ROIAlign with aligned = True for malformed boxes
@pytest.mark.skip(reason="Issue in exporting ROIAlign with aligned = True for malformed boxes")
def test_roi_align_malformed_boxes(self):
x = torch.randn(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 2, 0.3, 1.5, 1.5]], dtype=torch.float32)
......@@ -527,4 +525,4 @@ class ONNXExporterTester(unittest.TestCase):
if __name__ == '__main__':
unittest.main()
pytest.main([__file__])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment