"vscode:/vscode.git/clone" did not exist on "716286f19ddd9eb417113e064b538706884c8e73"
test_tensorrt_preprocess.py 2.18 KB
Newer Older
limm's avatar
limm committed
1
# Copyright (c) OpenMMLab. All rights reserved.
limm's avatar
limm committed
2
3
4
5
import os
import tempfile
from functools import wraps

import onnx
limm's avatar
limm committed
6
import pytest
limm's avatar
limm committed
7
8
9
10
11
import torch

from mmcv.ops import nms
from mmcv.tensorrt.preprocess import preprocess_onnx

limm's avatar
limm committed
12
13
14
# Skip every test in this module when torch reports its version string as
# 'parrots'; the skip reason states these tests are not supported there yet.
_TORCH_IS_PARROTS = torch.__version__ == 'parrots'
if _TORCH_IS_PARROTS:
    pytest.skip('not supported in parrots now', allow_module_level=True)

limm's avatar
limm committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80

def remove_tmp_file(func):
    """Run ``func`` with a scratch ONNX path and clean it up afterwards.

    The wrapped function is always called with an ``onnx_file`` keyword
    argument, overriding any value the caller passed. The argument points to
    ``tmp.onnx`` inside a freshly created temporary directory. That directory,
    and anything written into it, is removed once ``func`` returns or raises.
    ``func``'s return value is passed through unchanged, and any exception
    propagates to the caller.

    A private directory is used instead of ``tmp.onnx`` in the current
    working directory. That way concurrent runs cannot collide on one file,
    and an unrelated ``tmp.onnx`` already sitting in the working directory
    is never deleted.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # TemporaryDirectory removes the directory (and the exported model)
        # on both normal return and exception.
        with tempfile.TemporaryDirectory() as tmp_dir:
            kwargs['onnx_file'] = os.path.join(tmp_dir, 'tmp.onnx')
            return func(*args, **kwargs)

    return wrapper


@remove_tmp_file
def export_nms_module_to_onnx(module, onnx_file):
    """Export an NMS wrapper module to ONNX and load the result back.

    Args:
        module: A ``torch.nn.Module`` subclass (not an instance). It is
            instantiated with no arguments, put in eval mode and exported
            with inputs named ``boxes`` and ``scores``.
        onnx_file: Path the exported model is written to. It is supplied by
            the ``remove_tmp_file`` decorator, which also removes the file
            afterwards, so callers omit it.

    Returns:
        The exported model as returned by ``onnx.load``.
    """
    torch_model = module()
    torch_model.eval()

    # Random dummy inputs: ``boxes`` of shape (100, 4) and ``scores`` of
    # shape (100,). They are named ``dummy_inputs`` so the ``input`` builtin
    # is not shadowed.
    dummy_inputs = (torch.rand([100, 4], dtype=torch.float32),
                    torch.rand([100], dtype=torch.float32))

    torch.onnx.export(
        torch_model,
        dummy_inputs,
        onnx_file,
        opset_version=11,
        input_names=['boxes', 'scores'],
        output_names=['output'])

    onnx_model = onnx.load(onnx_file)
    return onnx_model


def test_can_handle_nms_with_constant_maxnum():
    """Preprocessing must handle NMS exported with a constant ``max_num``.

    Exports ``nms`` with ``max_num=10`` and runs ``preprocess_onnx``. It then
    checks that the model contains at least one NonMaxSuppression node. Each
    such node must carry 5 attributes, including a positive
    ``max_output_boxes_per_class``.
    """

    class ModuleNMS(torch.nn.Module):

        def forward(self, boxes, scores):
            return nms(boxes, scores, iou_threshold=0.4, max_num=10)

    onnx_model = export_nms_module_to_onnx(ModuleNMS)
    preprocess_onnx_model = preprocess_onnx(onnx_model)
    nms_nodes = [
        node for node in preprocess_onnx_model.graph.node
        if node.op_type == 'NonMaxSuppression'
    ]
    # Guard against a vacuous pass: without this, a model with no NMS node
    # would skip every assertion below.
    assert nms_nodes, 'No NonMaxSuppression node found in the model.'
    for node in nms_nodes:
        assert len(node.attribute) == 5, 'The NMS must have 5 attributes.'
        # Look the attribute up by name so the check does not depend on the
        # order in which attributes are stored on the node.
        attrs = {attr.name: attr.i for attr in node.attribute}
        assert attrs.get('max_output_boxes_per_class', 0) > 0, \
            'The max_output_boxes_per_class is not defined correctly.'


def test_can_handle_nms_with_undefined_maxnum():
    """Preprocessing must give NMS a positive box limit when none is set.

    Exports ``nms`` without ``max_num`` and runs ``preprocess_onnx``. It then
    checks that the model contains at least one NonMaxSuppression node. Each
    such node must have 5 attributes and a positive
    ``max_output_boxes_per_class``.
    """

    class ModuleNMS(torch.nn.Module):

        def forward(self, boxes, scores):
            return nms(boxes, scores, iou_threshold=0.4)

    onnx_model = export_nms_module_to_onnx(ModuleNMS)
    preprocess_onnx_model = preprocess_onnx(onnx_model)
    nms_nodes = [
        node for node in preprocess_onnx_model.graph.node
        if node.op_type == 'NonMaxSuppression'
    ]
    # Guard against a vacuous pass: without this, a model with no NMS node
    # would skip every assertion below.
    assert nms_nodes, 'No NonMaxSuppression node found in the model.'
    for node in nms_nodes:
        assert len(node.attribute) == 5, \
            'The NMS must have 5 attributes.'
        # Look the attribute up by name so the check does not depend on the
        # order in which attributes are stored on the node.
        attrs = {attr.name: attr.i for attr in node.attribute}
        assert attrs.get('max_output_boxes_per_class', 0) > 0, \
            'The max_output_boxes_per_class is not defined correctly.'