test_onnx2rknn.py 1.66 KB
Newer Older
limm's avatar
limm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile

import mmengine
import pytest
import torch
import torch.nn as nn

from mmdeploy.utils import Backend
from mmdeploy.utils.test import backend_checker

# Reserve a unique '.onnx' path for the exported model. The placeholder file
# is closed (and, with the default ``delete=True``, removed) as soon as the
# ``with`` block exits, so only the path string is kept. Closing explicitly
# avoids relying on garbage collection to release the handle and unlink the
# path at some later, unspecified time.
with tempfile.NamedTemporaryFile(suffix='.onnx') as _tmp_onnx:
    onnx_file = _tmp_onnx.name
# Random dummy input passed to ``torch.onnx.export`` as the example input.
test_img = torch.rand([1, 3, 8, 8])


@pytest.mark.skip(reason='This a not test class but a utility class.')
class TestModel(nn.Module):
    """Minimal parameter-free module used as the ONNX export source.

    The skip marker tells pytest that this is a utility class rather than a
    test class.
    """

    def forward(self, x):
        """Return ``x`` multiplied by 0.5."""
        halved = x * 0.5
        return halved


# Shared module instance, switched to evaluation mode before it is exported.
test_model = TestModel()
test_model.eval()


def generate_onnx_file(model):
    """Export ``model`` to the module-level ``onnx_file`` path.

    The export runs under ``torch.no_grad()`` with the module-level
    ``test_img`` as the example input, and the function asserts that the
    ONNX file exists afterwards.

    Args:
        model: Module handed to ``torch.onnx.export``.
    """
    export_options = dict(
        output_names=['output'],
        input_names=['input'],
        keep_initializers_as_inputs=True,
        do_constant_folding=True,
        verbose=False,
        opset_version=11)
    with torch.no_grad():
        torch.onnx.export(model, test_img, onnx_file, **export_options)
    assert osp.exists(onnx_file)


def get_deploy_cfg():
    """Build the deploy config used by the RKNN conversion test.

    Returns:
        mmengine.Config: Config whose ``backend_config`` has type ``'rknn'``,
        an empty ``common_config``, quantization switched off with no
        dataset, and a single ``[3, 8, 8]`` entry in ``input_size_list``.
    """
    quantization_config = dict(do_quantization=False, dataset=None)
    backend_config = dict(
        type='rknn',
        common_config=dict(),
        quantization_config=quantization_config,
        input_size_list=[[3, 8, 8]])
    return mmengine.Config(dict(backend_config=backend_config))


@backend_checker(Backend.RKNN)
def test_onnx2rknn():
    """Convert a tiny exported ONNX model to RKNN and check the result.

    Exports ``test_model`` to ``onnx_file``, runs ``onnx2rknn`` with the
    config from ``get_deploy_cfg()``, and asserts that both the working
    directory and the generated ``.rknn`` file exist.
    """
    # Imported inside the test rather than at module level -- presumably so
    # the module still imports when the RKNN backend is not installed.
    from mmdeploy.backend.rknn.onnx2rknn import onnx2rknn
    generate_onnx_file(test_model)

    work_dir = osp.dirname(onnx_file)
    # Swap only the trailing extension; ``str.replace`` would also rewrite
    # any '.onnx' that happens to appear in the directory part of the path.
    rknn_file = osp.splitext(onnx_file)[0] + '.rknn'
    deploy_cfg = get_deploy_cfg()
    onnx2rknn(onnx_file, rknn_file, deploy_cfg)
    assert osp.exists(work_dir)
    assert osp.exists(rknn_file)