".gitattributes" did not exist on "2dce1ab40bf56d10145762249612c5e11af08541"
test_onnx.py 2.04 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
from functools import partial

import numpy as np
import onnx
import onnxruntime as rt
import torch
import torch.nn as nn

# Path of the temporary file that test_nms exports the ONNX model to,
# loads back for inference, and removes afterwards.
onnx_file = 'tmp.onnx'


class WrapFunction(nn.Module):
    """Module adapter that forwards every call to a stored callable.

    Wrapping a callable in an ``nn.Module`` lets it be handed to APIs that
    expect a module instance.
    """

    def __init__(self, wrapped_function):
        """Store ``wrapped_function`` as the callable that forward invokes."""
        super().__init__()
        self.wrapped_function = wrapped_function

    def forward(self, *args, **kwargs):
        """Call the stored callable with the given arguments.

        All positional and keyword arguments are passed through unchanged
        and the callable's result is returned as-is.
        """
        target = self.wrapped_function
        return target(*args, **kwargs)


class Testonnx(object):
    """Tests comparing ONNX-exported ops against their PyTorch results."""

    def test_nms(self):
        """Export ``nms`` to ONNX and compare its output with PyTorch.

        The op is run directly on four float32 boxes with their scores
        (``iou_threshold=0.3``, ``offset=0``), then exported at opset 11
        through ``WrapFunction`` and executed with onnxruntime.  The test
        passes when column 4 of both detection outputs (read here as the
        score column) agree within ``atol=1e-3``.

        The exported file is removed in a ``finally`` clause so a failing
        export, load, inference or assertion does not leave it behind.
        """
        from mmcv.ops import nms
        np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0],
                             [3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]],
                            dtype=np.float32)
        np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32)
        boxes = torch.from_numpy(np_boxes)
        scores = torch.from_numpy(np_scores)

        # Reference result computed directly in PyTorch.
        pytorch_dets, _ = nms(boxes, scores, iou_threshold=0.3, offset=0)
        pytorch_score = pytorch_dets[:, 4]

        # Bind the thresholds so the exported graph only takes the two
        # tensors as inputs; a separate name avoids shadowing the op.
        nms_fn = partial(nms, iou_threshold=0.3, offset=0)
        wrapped_model = WrapFunction(nms_fn)
        wrapped_model.cpu().eval()

        try:
            with torch.no_grad():
                torch.onnx.export(
                    wrapped_model, (boxes, scores),
                    onnx_file,
                    export_params=True,
                    keep_initializers_as_inputs=True,
                    input_names=['boxes', 'scores'],
                    opset_version=11)
            onnx_model = onnx.load(onnx_file)

            # Initializers are also listed as graph inputs because of
            # keep_initializers_as_inputs=True; drop them to obtain the
            # inputs that actually have to be fed.
            input_all = [node.name for node in onnx_model.graph.input]
            input_initializer = [
                node.name for node in onnx_model.graph.initializer
            ]
            net_feed_input = list(set(input_all) - set(input_initializer))
            assert len(net_feed_input) == 2

            # get onnx output
            sess = rt.InferenceSession(onnx_file)
            onnx_dets, _ = sess.run(None, {
                'scores': scores.detach().numpy(),
                'boxes': boxes.detach().numpy()
            })
            onnx_score = onnx_dets[:, 4]
            assert np.allclose(pytorch_score, onnx_score, atol=1e-3)
        finally:
            # Clean up even when a step above raised; the file may not
            # exist if the export itself failed.
            if os.path.exists(onnx_file):
                os.remove(onnx_file)