test_nms_match_small.py 4.01 KB
Newer Older
limm's avatar
limm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Copyright (c) OpenMMLab. All rights reserved.
import os
import tempfile

import numpy
import onnxruntime
import pytest
import torch

from mmdeploy.backend.onnxruntime.init_plugins import get_ops_path
from mmdeploy.mmcv.ops import ONNXNMSMatchOp

# Absolute directory of this test file; the exported ONNX model is written here.
cur_dir = os.path.dirname(os.path.abspath(__file__))

# Ten candidate boxes, 4 coordinates each. Both batch entries use the same
# boxes, so the table is written once and repeated into a (2, 10, 4) tensor.
boxes = torch.tensor([
    [291.1746, 316.2263, 343.5029, 347.7312],
    [288.4846, 315.0447, 343.7267, 346.5630],
    [288.5307, 318.1989, 341.6425, 349.7222],
    [918.9102, 83.7463, 933.3920, 164.9041],
    [895.5786, 78.2361, 907.8049, 172.0883],
    [292.5816, 316.5563, 340.3462, 352.9989],
    [609.4592, 83.5447, 631.2532, 144.0749],
    [917.7308, 85.5870, 933.2839, 168.4530],
    [895.5138, 79.3596, 908.2865, 171.0418],
    [291.4747, 318.6987, 347.1208, 349.5754],
]).unsqueeze(0).repeat(2, 1, 1)

# One row of 5 scores per box (10 x 5), repeated for both batch entries,
# then the last two axes are swapped: (2, 10, 5) -> (2, 5, 10).
# NOTE(review): presumably the op expects (batch, num_classes, num_boxes) —
# confirm against the ONNXNMSMatchOp definition.
scores = torch.tensor([
    [0.9577, 0.9745, 0.3030, 0.6589, 0.2742],
    [0.1618, 0.7963, 0.5124, 0.6964, 0.6850],
    [0.8425, 0.4843, 0.9489, 0.8068, 0.7340],
    [0.7337, 0.4340, 0.9923, 0.0704, 0.4506],
    [0.3090, 0.5606, 0.6939, 0.3764, 0.6920],
    [0.0044, 0.7986, 0.2221, 0.2782, 0.4378],
    [0.7293, 0.2735, 0.8381, 0.0264, 0.6278],
    [0.7144, 0.1066, 0.4125, 0.4041, 0.8819],
    [0.4963, 0.7891, 0.6908, 0.1499, 0.5584],
    [0.4385, 0.6035, 0.0508, 0.0662, 0.5938],
]).unsqueeze(0).repeat(2, 1, 1).permute(0, 2, 1)

# Scalar thresholds, wrapped in 1-element tensors so they become graph inputs.
iou_threshold = torch.tensor([0.1])
score_threshold = torch.tensor([0.1])
# Callable entry point of the custom NMSMatch op. ``.apply`` suggests
# ONNXNMSMatchOp is a torch.autograd.Function (NOTE(review): confirm). The
# wrapper module's forward calls this, so it is what torch.onnx.export traces.
match_op = ONNXNMSMatchOp.apply


class test_ONNX_Match(torch.nn.Module):
    """Parameter-free ``nn.Module`` wrapper around ``match_op``.

    Exists so the NMSMatch op can be run eagerly and exported with
    ``torch.onnx.export``, which requires a module (or function) to trace.
    """

    def forward(self, boxes, scores, iou_threshold, score_threshold):
        """Pass all four inputs straight to ``match_op`` and return its result."""
        result = match_op(boxes, scores, iou_threshold, score_threshold)
        return result


@pytest.mark.skipif(
    reason='Need to build onnxruntime custom op',
    condition=get_ops_path() == '')
def test_nms_match():
    """Check the ONNX Runtime NMSMatch custom op against the PyTorch path.

    The ``test_ONNX_Match`` wrapper is run eagerly to get a reference
    output. It is then exported to a temporary ``.onnx`` file in this test's
    directory and executed with ONNX Runtime on the CPU provider, with the
    mmdeploy custom-op library registered. Both outputs must be exactly
    equal.

    The temporary model file is removed even when export or inference
    raises.
    """
    model = test_ONNX_Match()
    torch_output = model(boxes, scores, iou_threshold,
                         score_threshold).detach().numpy()

    # Only reserve a unique path here. Closing the handle immediately lets
    # torch.onnx.export reopen the file by name, which is not allowed on
    # Windows while a NamedTemporaryFile is still open.
    with tempfile.NamedTemporaryFile(
            suffix='.onnx', delete=False, dir=cur_dir) as temp_onnx:
        onnx_path = temp_onnx.name

    try:
        input_names = ['boxes', 'scores', 'iou_thr', 'score_thr']
        torch.onnx.export(
            model,
            (boxes, scores, iou_threshold, score_threshold),
            onnx_path,
            input_names=input_names,
        )

        options = onnxruntime.SessionOptions()
        options.register_custom_ops_library(get_ops_path())
        sess = onnxruntime.InferenceSession(
            onnx_path, options, providers=['CPUExecutionProvider'])
        ort_output = sess.run(
            None,
            {
                'boxes': boxes.numpy(),
                'scores': scores.numpy(),
                'iou_thr': iou_threshold.numpy(),
                'score_thr': score_threshold.numpy(),
            },
        )
    finally:
        # Clean up unconditionally so a failing export/inference does not
        # leave a stray .onnx file in the test directory.
        os.remove(onnx_path)

    assert numpy.array_equal(torch_output,
                             ort_output[0]), 'list are not equal'