"test_modeling_utils.py" did not exist on "3a5c87055c38b1b125029c87dcb1a6d7c4eb6c61"
test_tensorrt_preprocess.py 2.18 KB
Newer Older
1
# Copyright (c) OpenMMLab. All rights reserved.
SemyonBevzuk's avatar
SemyonBevzuk committed
2
3
4
5
import os
from functools import wraps

import onnx
pc's avatar
pc committed
6
import pytest
SemyonBevzuk's avatar
SemyonBevzuk committed
7
8
9
10
11
import torch

from mmcv.ops import nms
from mmcv.tensorrt.preprocess import preprocess_onnx

pc's avatar
pc committed
12
13
14
if torch.__version__ == 'parrots':
    pytest.skip('not supported in parrots now', allow_module_level=True)

SemyonBevzuk's avatar
SemyonBevzuk committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80

def remove_tmp_file(func):

    @wraps(func)
    def wrapper(*args, **kwargs):
        onnx_file = 'tmp.onnx'
        kwargs['onnx_file'] = onnx_file
        try:
            result = func(*args, **kwargs)
        finally:
            if os.path.exists(onnx_file):
                os.remove(onnx_file)
        return result

    return wrapper


@remove_tmp_file
def export_nms_module_to_onnx(module, onnx_file):
    torch_model = module()
    torch_model.eval()

    input = (torch.rand([100, 4], dtype=torch.float32),
             torch.rand([100], dtype=torch.float32))

    torch.onnx.export(
        torch_model,
        input,
        onnx_file,
        opset_version=11,
        input_names=['boxes', 'scores'],
        output_names=['output'])

    onnx_model = onnx.load(onnx_file)
    return onnx_model


def test_can_handle_nms_with_constant_maxnum():

    class ModuleNMS(torch.nn.Module):

        def forward(self, boxes, scores):
            return nms(boxes, scores, iou_threshold=0.4, max_num=10)

    onnx_model = export_nms_module_to_onnx(ModuleNMS)
    preprocess_onnx_model = preprocess_onnx(onnx_model)
    for node in preprocess_onnx_model.graph.node:
        if 'NonMaxSuppression' in node.name:
            assert len(node.attribute) == 5, 'The NMS must have 5 attributes.'


def test_can_handle_nms_with_undefined_maxnum():

    class ModuleNMS(torch.nn.Module):

        def forward(self, boxes, scores):
            return nms(boxes, scores, iou_threshold=0.4)

    onnx_model = export_nms_module_to_onnx(ModuleNMS)
    preprocess_onnx_model = preprocess_onnx(onnx_model)
    for node in preprocess_onnx_model.graph.node:
        if 'NonMaxSuppression' in node.name:
            assert len(node.attribute) == 5, \
                'The NMS must have 5 attributes.'
            assert node.attribute[2].i > 0, \
                'The max_output_boxes_per_class is not defined correctly.'