Unverified Commit 23b2bdbf authored by robin Han's avatar robin Han Committed by GitHub
Browse files

add unittest for onnx convert (#608)

* add unittest for onnx convert

* build onnx and onnxruntime in CI

* skip onnx op unit test while using CUDA

* fix offset==0 case in NMS

* remove tmp file used in test

* delete tmp file before assert so that we can remove the tmp file anyway
parent 65a60a3d
...@@ -111,7 +111,7 @@ jobs: ...@@ -111,7 +111,7 @@ jobs:
- name: Install PyTorch - name: Install PyTorch
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install unittest dependencies - name: Install unittest dependencies
run: pip install pytest coverage lmdb PyTurboJPEG run: pip install pytest coverage lmdb PyTurboJPEG onnx==1.6.0 onnxruntime==1.2.0
- name: Build and install - name: Build and install
run: rm -rf .eggs && pip install -e . run: rm -rf .eggs && pip install -e .
- name: Run unittests and generate coverage report - name: Run unittests and generate coverage report
...@@ -181,7 +181,7 @@ jobs: ...@@ -181,7 +181,7 @@ jobs:
run: rm -rf .eggs && pip install -e . run: rm -rf .eggs && pip install -e .
- name: Run unittests and generate coverage report - name: Run unittests and generate coverage report
run: | run: |
coverage run --branch --source=mmcv -m pytest tests/ coverage run --branch --source=mmcv -m pytest tests/ --ignore=tests/test_ops/test_onnx.py
coverage xml coverage xml
coverage report -m coverage report -m
# Only upload coverage report for python3.7 && pytorch1.5 # Only upload coverage report for python3.7 && pytorch1.5
...@@ -220,6 +220,8 @@ jobs: ...@@ -220,6 +220,8 @@ jobs:
- name: Install Pillow - name: Install Pillow
run: pip install Pillow==6.2.2 run: pip install Pillow==6.2.2
if: ${{matrix.torchvision == '0.4.2'}} if: ${{matrix.torchvision == '0.4.2'}}
- name: Install ONNX
run: pip install onnx==1.6.0 onnxruntime==1.2.0
- name: Install PyTorch - name: Install PyTorch
run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} --no-cache-dir run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} --no-cache-dir
- name: Build and install - name: Build and install
......
...@@ -111,6 +111,9 @@ def nms(boxes, scores, iou_threshold, offset=0): ...@@ -111,6 +111,9 @@ def nms(boxes, scores, iou_threshold, offset=0):
# ONNX only support offset == 1 # ONNX only support offset == 1
boxes[:, -2:] -= 1 boxes[:, -2:] -= 1
inds = NMSop.apply(boxes, scores, iou_threshold, offset) inds = NMSop.apply(boxes, scores, iou_threshold, offset)
if torch.onnx.is_in_onnx_export() and offset == 0:
# ONNX only support offset == 1
boxes[:, -2:] += 1
dets = torch.cat((boxes[inds], scores[inds].reshape(-1, 1)), dim=1) dets = torch.cat((boxes[inds], scores[inds].reshape(-1, 1)), dim=1)
if is_numpy: if is_numpy:
dets = dets.cpu().numpy() dets = dets.cpu().numpy()
......
...@@ -14,6 +14,6 @@ line_length = 79 ...@@ -14,6 +14,6 @@ line_length = 79
multi_line_output = 0 multi_line_output = 0
known_standard_library = pkg_resources,setuptools known_standard_library = pkg_resources,setuptools
known_first_party = mmcv known_first_party = mmcv
known_third_party = Cython,addict,cv2,m2r,numpy,pytest,recommonmark,resnet_cifar,torch,torchvision,yaml,yapf known_third_party = Cython,addict,cv2,m2r,numpy,onnx,onnxruntime,pytest,recommonmark,resnet_cifar,torch,torchvision,yaml,yapf
no_lines_before = STDLIB,LOCALFOLDER no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY default_section = THIRDPARTY
import os
from functools import partial
import numpy as np
import onnx
import onnxruntime as rt
import torch
import torch.nn as nn
onnx_file = 'tmp.onnx'
class WrapFunction(nn.Module):
def __init__(self, wrapped_function):
super(WrapFunction, self).__init__()
self.wrapped_function = wrapped_function
def forward(self, *args, **kwargs):
return self.wrapped_function(*args, **kwargs)
class Testonnx(object):
def test_nms(self):
from mmcv.ops import nms
np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0],
[3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]],
dtype=np.float32)
np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32)
boxes = torch.from_numpy(np_boxes)
scores = torch.from_numpy(np_scores)
pytorch_dets, _ = nms(boxes, scores, iou_threshold=0.3, offset=0)
pytorch_score = pytorch_dets[:, 4]
nms = partial(nms, iou_threshold=0.3, offset=0)
wrapped_model = WrapFunction(nms)
wrapped_model.cpu().eval()
with torch.no_grad():
torch.onnx.export(
wrapped_model, (boxes, scores),
onnx_file,
export_params=True,
keep_initializers_as_inputs=True,
input_names=['boxes', 'scores'],
opset_version=11)
onnx_model = onnx.load(onnx_file)
# get onnx output
input_all = [node.name for node in onnx_model.graph.input]
input_initializer = [
node.name for node in onnx_model.graph.initializer
]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 2)
sess = rt.InferenceSession(onnx_file)
onnx_dets, _ = sess.run(None, {
'scores': scores.detach().numpy(),
'boxes': boxes.detach().numpy()
})
onnx_score = onnx_dets[:, 4]
os.remove(onnx_file)
assert np.allclose(pytorch_score, onnx_score, atol=1e-3)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment