# test_onnx.py
import io
import torch
from torchvision import ops
from torchvision.models.detection.image_list import ImageList
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from collections import OrderedDict

# onnxruntime requires python 3.5 or above
try:
    import onnxruntime
except ImportError:
    onnxruntime = None

import unittest


@unittest.skipIf(onnxruntime is None, 'ONNX Runtime unavailable')
class ONNXExporterTester(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        """Seed torch's random number generator with a fixed value for this test class."""
        torch.manual_seed(123)

    def run_model(self, model, inputs_list, tolerate_small_mismatch=False):
        """Export ``model`` to ONNX and check ONNX Runtime against PyTorch.

        The model is put in eval mode and exported once, using the first entry
        of ``inputs_list``. Each entry of ``inputs_list`` is then run through
        the PyTorch model, and the inputs and resulting outputs are handed to
        ``ort_validate`` together with the exported model.

        Args:
            model: module to export.
            inputs_list: sequence of model inputs; an entry that is a bare
                tensor or a list is wrapped in a one-element tuple before the
                model is called with it.
            tolerate_small_mismatch: forwarded unchanged to ``ort_validate``.
        """
        model.eval()

        # export to onnx with the first input
        exported = io.BytesIO()
        torch.onnx.export(
            model, inputs_list[0], exported,
            do_constant_folding=True, opset_version=10)

        # validate the exported model with onnx runtime
        for sample in inputs_list:
            with torch.no_grad():
                if isinstance(sample, (torch.Tensor, list)):
                    sample = (sample,)
                expected = model(*sample)
                if isinstance(expected, torch.Tensor):
                    expected = (expected,)
            self.ort_validate(exported, sample, expected, tolerate_small_mismatch)
    def ort_validate(self, onnx_io, inputs, outputs, tolerate_small_mismatch=False):
        """Run the exported model in ONNX Runtime and compare with ``outputs``.

        Args:
            onnx_io: buffer whose ``getvalue()`` returns the serialized ONNX
                model.
            inputs: tensors (flattened with ``torch.jit._flatten``) that are
                bound, in order, to the inputs of the ONNX Runtime session.
            outputs: reference tensors (flattened the same way) that the
                session outputs are compared against, position by position.
            tolerate_small_mismatch: when true, a failed comparison is
                accepted if its error message contains ``"(0.00%)"``.

        Raises:
            AssertionError: if an output is not close to its reference
                (rtol=1e-3, atol=1e-5) and the mismatch is not tolerated.
        """
        inputs, _ = torch.jit._flatten(inputs)
        outputs, _ = torch.jit._flatten(outputs)

        def to_numpy(tensor):
            # detach() is safe whether or not the tensor requires grad
            return tensor.detach().cpu().numpy()

        inputs = [to_numpy(tensor) for tensor in inputs]
        outputs = [to_numpy(tensor) for tensor in outputs]

        ort_session = onnxruntime.InferenceSession(onnx_io.getvalue())
        # compute onnxruntime output prediction
        session_inputs = ort_session.get_inputs()
        ort_inputs = {session_inputs[i].name: value for i, value in enumerate(inputs)}
        ort_outs = ort_session.run(None, ort_inputs)

        for i, expected in enumerate(outputs):
            try:
                torch.testing.assert_allclose(expected, ort_outs[i], rtol=1e-03, atol=1e-05)
            except AssertionError as error:
                message = str(error)
                # Fail explicitly rather than via `assert`: assert statements
                # are removed under `python -O`, which would let a mismatch
                # pass silently.
                if not (tolerate_small_mismatch and "(0.00%)" in message):
                    self.fail(message)

    def test_nms(self):
        # Export a thin wrapper around ops.nms (third argument 0.5) and
        # validate it on five random boxes with random scores.
        class Module(torch.nn.Module):
            def forward(self, boxes, scores):
                return ops.nms(boxes, scores, 0.5)

        num_boxes = 5
        candidate_boxes = torch.rand(num_boxes, 4)
        # add a random offset to the last two coordinates of every box
        candidate_boxes[:, 2:] += torch.rand(num_boxes, 2)
        candidate_scores = torch.randn(num_boxes)

        self.run_model(Module(), [(candidate_boxes, candidate_scores)])
    def test_roi_align(self):
        # Export ops.RoIAlign((5, 5), 1, 2) and validate it with a single ROI
        # row on a random 1x1x10x10 float32 input.
        feature_map = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
        roi_align = ops.RoIAlign((5, 5), 1, 2)
        self.run_model(roi_align, [(feature_map, rois)])
    def test_roi_pool(self):
        # Export ops.RoIPool with a 5x5 output size and validate it with a
        # single ROI row on a random 1x1x10x10 float32 input.
        feature_map = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        boxes = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
        output_size = (5, 5)  # (pool_h, pool_w)
        roi_pool = ops.RoIPool(output_size, 2)
        self.run_model(roi_pool, [(feature_map, boxes)])

    @unittest.skip("Disable test until Resize opset 11 is implemented in ONNX Runtime")
    def test_transform_images(self):
        # Wrap the detection transform in a module that returns only the
        # `tensors` attribute of the first element of the transform's result,
        # then export and validate it.
        test_case = self

        class TransformModule(torch.nn.Module):
            def __init__(self):
                super(TransformModule, self).__init__()
                self.transform = test_case._init_test_generalized_rcnn_transform()

            def forward(self, images):
                transformed = self.transform(images)
                return transformed[0].tensors

        def random_images():
            # two random 3-channel images of different widths
            return [torch.rand(3, 800, 1280), torch.rand(3, 800, 800)]

        export_images = random_images()
        check_images = random_images()
        self.run_model(TransformModule(), [export_images, check_images])
    def _init_test_generalized_rcnn_transform(self):
        """Build the ``GeneralizedRCNNTransform`` instance used by the tests.

        It is constructed with sizes 800 (min) and 1333 (max) and the
        three-element mean / std lists defined below.
        """
        smallest_side, largest_side = 800, 1333
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        return GeneralizedRCNNTransform(smallest_side, largest_side, channel_mean, channel_std)

    def _init_test_rpn(self):
        """Build the ``RegionProposalNetwork`` instance used by the tests.

        The network is assembled from an ``AnchorGenerator`` with five anchor
        sizes (three aspect ratios each) and an ``RPNHead`` with 256 input
        channels.
        """
        sizes = ((32,), (64,), (128,), (256,), (512,))
        ratios = ((0.5, 1.0, 2.0),) * len(sizes)
        anchor_generator = AnchorGenerator(sizes, ratios)

        channels = 256
        anchors_per_location = anchor_generator.num_anchors_per_location()[0]
        head = RPNHead(channels, anchors_per_location)

        fg_iou_thresh, bg_iou_thresh = 0.7, 0.3
        batch_size_per_image, positive_fraction = 256, 0.5
        pre_nms_top_n = {"training": 2000, "testing": 1000}
        post_nms_top_n = {"training": 2000, "testing": 1000}
        nms_thresh = 0.7

        return RegionProposalNetwork(
            anchor_generator, head,
            fg_iou_thresh, bg_iou_thresh,
            batch_size_per_image, positive_fraction,
            pre_nms_top_n, post_nms_top_n, nms_thresh)

    def test_rpn(self):
        # Export the RPN bound to a fixed image batch and validate it on two
        # independently drawn sets of random feature maps.
        test_case = self

        class RPNModule(torch.nn.Module):
            def __init__(self, images):
                super(RPNModule, self).__init__()
                self.rpn = test_case._init_test_rpn()
                image_sizes = [image.shape[-2:] for image in images]
                self.images = ImageList(images, image_sizes)

            def forward(self, features):
                return self.rpn(self.images, features)

        def get_features(images):
            # one random 256-channel map per level; spatial size is the image
            # size integer-divided by the level's divisor
            height, width = images.shape[-2:]
            levels = (('0', 4), ('1', 8), ('2', 16), ('3', 32), ('4', 64))
            return OrderedDict(
                [(name, torch.rand(2, 256, height // divisor, width // divisor))
                 for name, divisor in levels])

        images = torch.rand(2, 3, 600, 600)
        export_features = get_features(images)
        check_features = get_features(images)

        model = RPNModule(images)
        model.eval()
        model(export_features)
        self.run_model(
            model,
            [(export_features,), (check_features,)],
            tolerate_small_mismatch=True)

    def test_multi_scale_roi_align(self):
        # Export ops.MultiScaleRoIAlign over two named feature maps with a
        # fixed image size and validate it on two random samples.

        class TransformModule(torch.nn.Module):
            def __init__(self):
                super(TransformModule, self).__init__()
                self.model = ops.MultiScaleRoIAlign(['feat1', 'feat2'], 3, 2)
                self.image_sizes = [(512, 512)]

            def forward(self, input, boxes):
                return self.model(input, boxes, self.image_sizes)

        def make_sample():
            # (feature dict, [boxes]) pair; the last two box coordinates are
            # offset by the first two
            feature_maps = OrderedDict([
                ('feat1', torch.rand(1, 5, 64, 64)),
                ('feat2', torch.rand(1, 5, 16, 16)),
            ])
            rois = torch.rand(6, 4) * 256
            rois[:, 2:] += rois[:, :2]
            return feature_maps, [rois]

        export_sample = make_sample()
        check_sample = make_sample()

        self.run_model(TransformModule(), [export_sample, check_sample])


# Run the tests in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()