test_mark.py 1.86 KB
Newer Older
limm's avatar
limm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Copyright (c) OpenMMLab. All rights reserved.
import os
import tempfile

import onnx
import torch

from mmdeploy.core import RewriterContext, mark
from mmdeploy.core.optimizers import attribute_to_dict
from mmdeploy.utils.constants import IR, Backend

# Destination for the ONNX model exported by the tests below.
# Taking ``.name`` from a discarded ``NamedTemporaryFile`` only reserves a path
# whose file vanishes once the object is garbage collected (racy, like the
# deprecated ``tempfile.mktemp``), and whatever is later written there is never
# removed. A module-owned ``TemporaryDirectory`` keeps the path private and is
# deleted, along with its contents, when the interpreter exits.
_output_dir = tempfile.TemporaryDirectory()
output_file = os.path.join(_output_dir.name, 'test_mark.onnx')


def _assert_mark_node(node, **expected_attrs):
    """Assert that ``node`` is an ``mmdeploy``-domain ``Mark`` node.

    Args:
        node: A node taken from an exported ONNX graph.
        **expected_attrs: The exact dict that
            ``attribute_to_dict(node.attribute)`` must equal.
    """
    assert node.op_type == 'Mark'
    assert node.domain == 'mmdeploy'
    assert attribute_to_dict(node.attribute) == expected_attrs


def test_mark():
    """Check that ``@mark`` surrounds the marked function with Mark nodes.

    The ONNX export must place one ``Mark`` node per declared input (``a``,
    ``b``) before the ``Add`` and one per declared output (``c``) after it,
    each carrying the function name/id, tensor index, dtype and shape.
    Tracing the same model under the TorchScript rewriter context must also
    succeed and still compute ``x + y``.
    """

    @mark('add', inputs=['a', 'b'], outputs='c')
    def add(x, y):
        return torch.add(x, y)

    class TestModel(torch.nn.Module):

        def forward(self, x, y):
            return add(x, y)

    model = TestModel().eval()

    # dummy input
    x = torch.rand(2, 3, 4)
    y = torch.rand(2, 3, 4)

    torch.onnx.export(model, (x, y), output_file)
    onnx_model = onnx.load(output_file)

    # Attributes shared by every Mark node emitted for ``add``. The shape
    # mirrors the (2, 3, 4) dummy inputs; dtype=1 presumably encodes their
    # float32 dtype.
    common = dict(dtype=1, func='add', func_id=0, shape=[2, 3, 4])

    nodes = onnx_model.graph.node
    _assert_mark_node(nodes[0], id=0, type='input', name='a', **common)
    _assert_mark_node(nodes[1], id=1, type='input', name='b', **common)
    assert nodes[2].op_type == 'Add'
    _assert_mark_node(nodes[3], id=0, type='output', name='c', **common)

    with RewriterContext(
            cfg=None, backend=Backend.TORCHSCRIPT.value,
            ir=IR.TORCHSCRIPT), torch.no_grad(), torch.jit.optimized_execution(
                True):
        traced = torch.jit.trace(model, (x, y))
        # Marks must not change the numerics of the traced module.
        assert torch.allclose(traced(x, y), x + y)