test_wrapper.py 7.39 KB
Newer Older
limm's avatar
limm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import subprocess
import tempfile

import mmengine
import pytest
import torch
import torch.nn as nn

from mmdeploy.utils.constants import Backend
from mmdeploy.utils.test import check_backend

# NOTE: ``NamedTemporaryFile(...).name`` is used only to obtain a unique path;
# the underlying file handle is unreferenced and, on CPython, the file is
# deleted as soon as it is garbage collected — only the path string survives
# for the converters below to write to (POSIX-friendly pattern).
onnx_file = tempfile.NamedTemporaryFile(suffix='.onnx').name
ts_file = tempfile.NamedTemporaryFile(suffix='.pt').name
# Dummy input tensor shared by the model, the exporter fixtures and the tests.
test_img = torch.rand(1, 3, 8, 8)
output_names = ['output']
input_names = ['input']
target_platform = 'rk3588'  # rknn pre-compiled model need device


@pytest.mark.skip(reason='This a not test class but a utility class.')
class TestModel(nn.Module):
    """Toy module that adds the module-level ``test_img`` tensor to its input.

    The ``skip`` mark prevents pytest from collecting it as a test class
    (the ``Test`` prefix would otherwise match pytest's discovery rules).
    """

    def forward(self, x):
        """Return ``x`` plus the shared ``test_img`` tensor."""
        return x + test_img


# Single shared model instance in eval mode, exported by the fixtures below.
model = TestModel().eval()


@pytest.fixture(autouse=True, scope='module')
def generate_onnx_file():
    """Export the shared ``model`` to ``onnx_file`` once per test module."""
    export_kwargs = dict(
        output_names=output_names,
        input_names=input_names,
        keep_initializers_as_inputs=True,
        do_constant_folding=True,
        verbose=False,
        opset_version=11,
        dynamic_axes=None)
    with torch.no_grad():
        torch.onnx.export(model, test_img, onnx_file, **export_kwargs)


@pytest.fixture(autouse=True, scope='module')
def generate_torchscript_file():
    """Trace the shared ``model`` and save the TorchScript file to ``ts_file``."""
    from mmengine import Config

    from mmdeploy.apis.torch_jit import trace

    backend = Backend.TORCHSCRIPT.value
    deploy_cfg = Config({'backend_config': dict(type=backend)})
    context_info = dict(deploy_cfg=deploy_cfg)
    output_prefix = osp.splitext(ts_file)[0]

    dummy_inputs = torch.rand(1, 3, 8, 8)
    trace(
        model,
        dummy_inputs,
        output_path_prefix=output_prefix,
        backend=backend,
        context_info=context_info)


def ir2backend(backend, onnx_file, ts_file):
    """Convert an intermediate representation (ONNX or TorchScript) into the
    model format of the given backend.

    Args:
        backend (Backend): Target backend enum member.
        onnx_file (str): Path to the source ONNX model.
        ts_file (str): Path to the source TorchScript model.

    Returns:
        str | tuple[str, str]: Path(s) to the converted model artifact(s);
        the exact form depends on the backend (e.g. ncnn yields param + bin).

    Raises:
        NotImplementedError: If conversion for ``backend`` is not handled.
    """
    if backend == Backend.TENSORRT:
        from mmdeploy.backend.tensorrt import from_onnx
        backend_file = tempfile.NamedTemporaryFile(suffix='.engine').name
        from_onnx(
            onnx_file,
            osp.splitext(backend_file)[0], {
                'input': {
                    'min_shape': [1, 3, 8, 8],
                    'opt_shape': [1, 3, 8, 8],
                    'max_shape': [1, 3, 8, 8]
                }
            })
        return backend_file
    elif backend == Backend.ONNXRUNTIME:
        # onnxruntime consumes the ONNX file directly.
        return onnx_file
    elif backend == Backend.PPLNN:
        from mmdeploy.apis.pplnn import from_onnx
        output_file_prefix = tempfile.NamedTemporaryFile().name
        from_onnx(onnx_file, output_file_prefix=output_file_prefix)
        algo_file = output_file_prefix + '.json'
        output_file = output_file_prefix + '.onnx'
        return output_file, algo_file
    elif backend == Backend.NCNN:
        from mmdeploy.backend.ncnn.init_plugins import get_onnx2ncnn_path
        onnx2ncnn_path = get_onnx2ncnn_path()
        param_file = tempfile.NamedTemporaryFile(suffix='.param').name
        bin_file = tempfile.NamedTemporaryFile(suffix='.bin').name
        # list-form argv, shell=False: no shell injection surface.
        subprocess.call([onnx2ncnn_path, onnx_file, param_file, bin_file])
        return param_file, bin_file
    elif backend == Backend.OPENVINO:
        from mmdeploy.apis.openvino import from_onnx, get_output_model_file
        # Fix: ``tempfile.TemporaryDirectory().name`` removes the directory as
        # soon as the unreferenced object is finalized; mkdtemp() persists.
        backend_dir = tempfile.mkdtemp()
        backend_file = get_output_model_file(onnx_file, backend_dir)
        input_info = {'input': test_img.shape}
        output_names = ['output']
        work_dir = backend_dir
        from_onnx(onnx_file, work_dir, input_info, output_names)
        return backend_file
    elif backend == Backend.RKNN:
        from mmdeploy.apis.rknn import onnx2rknn
        rknn_file = onnx_file.replace('.onnx', '.rknn')
        deploy_cfg = mmengine.Config(
            dict(
                backend_config=dict(
                    type='rknn',
                    common_config=dict(target_platform=target_platform),
                    quantization_config=dict(
                        do_quantization=False, dataset=None),
                    input_size_list=[[3, 8, 8]])))
        onnx2rknn(onnx_file, rknn_file, deploy_cfg)
        return rknn_file
    elif backend == Backend.ASCEND:
        from mmdeploy.apis.ascend import from_onnx
        # Fix: mkdtemp() instead of TemporaryDirectory().name (see OPENVINO).
        backend_dir = tempfile.mkdtemp()
        work_dir = backend_dir
        file_name = osp.splitext(osp.split(onnx_file)[1])[0]
        backend_file = osp.join(work_dir, file_name + '.om')
        model_inputs = mmengine.Config(
            dict(input_shapes=dict(input=test_img.shape)))
        from_onnx(onnx_file, work_dir, model_inputs)
        return backend_file
    elif backend == Backend.TVM:
        from mmdeploy.backend.tvm import from_onnx, get_library_ext
        ext = get_library_ext()
        lib_file = tempfile.NamedTemporaryFile(suffix=ext).name
        shape = {'input': test_img.shape}
        dtype = {'input': 'float32'}
        target = 'llvm'
        tuner_dict = dict(type='DefaultTuner', target=target)
        from_onnx(
            onnx_file, lib_file, shape=shape, dtype=dtype, tuner=tuner_dict)
        assert osp.exists(lib_file)
        return lib_file
    elif backend == Backend.TORCHSCRIPT:
        # TorchScript is its own deployable format.
        return ts_file
    elif backend == Backend.COREML:
        output_names = ['output']
        from mmdeploy.backend.coreml.torchscript2coreml import (
            from_torchscript, get_model_suffix)
        # Fix: mkdtemp() instead of TemporaryDirectory().name (see OPENVINO).
        backend_dir = tempfile.mkdtemp()
        work_dir = backend_dir
        torchscript_name = osp.splitext(osp.split(ts_file)[1])[0]
        output_file_prefix = osp.join(work_dir, torchscript_name)
        convert_to = 'mlprogram'
        from_torchscript(
            ts_file,
            output_file_prefix,
            input_names=input_names,
            output_names=output_names,
            input_shapes=dict(
                input=dict(
                    min_shape=[1, 3, 8, 8],
                    default_shape=[1, 3, 8, 8],
                    max_shape=[1, 3, 8, 8])),
            convert_to=convert_to)
        suffix = get_model_suffix(convert_to)
        return output_file_prefix + suffix
    else:
        raise NotImplementedError(
            f'Convert for {backend.value} has not been implemented.')


def create_wrapper(backend, model_files):
    """Build a backend inference wrapper around converted model file(s).

    Args:
        backend (Backend): Backend the files were converted for.
        model_files (str | Sequence[str]): Converted model path(s).

    Returns:
        The wrapper produced by the backend manager's ``build_wrapper``.
    """
    from mmdeploy.backend.base import get_backend_manager
    backend_mgr = get_backend_manager(backend.value)
    if isinstance(model_files, str):
        model_files = [model_files]
    deploy_cfg = None
    # Fix: this must NOT be an ``elif`` of the isinstance check above.  The
    # rknn converter returns a single path string, so the original ``elif``
    # was dead for RKNN and the wrapper never received its target platform.
    if backend == Backend.RKNN:
        deploy_cfg = dict(
            backend_config=dict(
                common_config=dict(target_platform=target_platform)))
    return backend_mgr.build_wrapper(
        model_files,
        input_names=input_names,
        output_names=output_names,
        deploy_cfg=deploy_cfg)


def run_wrapper(backend, wrapper, input):
    """Run one forward pass through ``wrapper`` and return the CPU output."""
    # TensorRT engines only consume device tensors.
    feed = input.cuda() if backend == Backend.TENSORRT else input

    outputs = wrapper({'input': feed})

    # The rknn wrapper returns a bare tensor rather than a name->tensor dict.
    if backend != Backend.RKNN:
        outputs = outputs['output']

    return outputs.detach().cpu()


# Backends exercised by the parametrized test below; DEFAULT, PYTORCH and SDK
# are removed since they are not conversion targets handled by ir2backend.
ALL_BACKEND = list(Backend)
ALL_BACKEND.remove(Backend.DEFAULT)
ALL_BACKEND.remove(Backend.PYTORCH)
ALL_BACKEND.remove(Backend.SDK)


@pytest.mark.parametrize('backend', ALL_BACKEND)
def test_wrapper(backend):
    """End-to-end check: convert, wrap and run the toy model on ``backend``."""
    check_backend(backend, True)

    model_files = ir2backend(backend, onnx_file, ts_file)
    assert model_files is not None

    wrapper = create_wrapper(backend, model_files)
    assert wrapper is not None

    outputs = run_wrapper(backend, wrapper, test_img)
    assert outputs is not None