# test_onnx.py: ONNX export tests for torchvision ops and detection models,
# validated against ONNX Runtime.
# onnxruntime requires python 3.5 or above
try:
    # This import should be before that of torch
    # see https://github.com/onnx/onnx/issues/2394#issuecomment-581638840
    import onnxruntime
except ImportError:
    onnxruntime = None

import contextlib
import io
from collections import OrderedDict
from typing import List, Tuple

import pytest
import torch
from torchvision import models
from torchvision import ops
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from torchvision.models.detection.image_list import ImageList
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.ops._register_onnx_ops import _onnx_opset_version

from common_utils import set_rng_seed, assert_equal


@pytest.mark.skipif(onnxruntime is None, reason='ONNX Runtime unavailable')
class TestONNXExporter:
    """Export torchvision ops and detection models to ONNX and check that ONNX Runtime
    reproduces the PyTorch outputs."""

    @classmethod
    def setup_class(cls):
        torch.manual_seed(123)

    @staticmethod
    @contextlib.contextmanager
    def _jit_profiling_disabled():
        """Temporarily disable the JIT profiling executor and profiling mode.

        The ``torch._C._jit_set_profiling_*`` setters return the previous setting; it is
        restored on exit so the change does not leak into other tests run in the same process.
        """
        old_executor = torch._C._jit_set_profiling_executor(False)
        old_mode = torch._C._jit_set_profiling_mode(False)
        try:
            yield
        finally:
            torch._C._jit_set_profiling_mode(old_mode)
            torch._C._jit_set_profiling_executor(old_executor)

    def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True, dynamic_axes=None,
                  output_names=None, input_names=None):
        """Export ``model`` with the first entry of ``inputs_list`` and validate every entry with ONNX Runtime.

        Args:
            model: module to export; it is put in eval mode first.
            inputs_list: list of input tuples. The first tuple is used for the export. Each tuple
                is then run through both ``model`` and the exported graph and the outputs compared.
            tolerate_small_mismatch: forwarded to :meth:`ort_validate`.
            do_constant_folding, dynamic_axes, output_names, input_names: forwarded to
                ``torch.onnx.export``.
        """
        model.eval()

        onnx_io = io.BytesIO()
        # torch.onnx.export treats a trailing dict in the args tuple as keyword arguments;
        # appending an empty dict keeps a genuine dict input (e.g. feature maps) positional.
        if isinstance(inputs_list[0][-1], dict):
            torch_onnx_input = inputs_list[0] + ({},)
        else:
            torch_onnx_input = inputs_list[0]
        # export to onnx with the first input
        torch.onnx.export(model, torch_onnx_input, onnx_io,
                          do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version,
                          dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names)
        # validate the exported model with onnx runtime
        for test_inputs in inputs_list:
            with torch.no_grad():
                if isinstance(test_inputs, (torch.Tensor, list)):
                    test_inputs = (test_inputs,)
                test_outputs = model(*test_inputs)
                if isinstance(test_outputs, torch.Tensor):
                    test_outputs = (test_outputs,)
            self.ort_validate(onnx_io, test_inputs, test_outputs, tolerate_small_mismatch)

    def ort_validate(self, onnx_io, inputs, outputs, tolerate_small_mismatch=False):
        """Run the serialized ONNX model in ``onnx_io`` on ``inputs`` and compare it with ``outputs``.

        ``inputs`` and ``outputs`` are flattened with ``torch.jit._flatten``. They are then matched
        positionally against the session's inputs and outputs, and compared with rtol=1e-3 and
        atol=1e-5. When ``tolerate_small_mismatch`` is set, an assertion failure is accepted only
        if its message contains "(0.00%)".
        """
        inputs, _ = torch.jit._flatten(inputs)
        outputs, _ = torch.jit._flatten(outputs)

        def to_numpy(tensor):
            if tensor.requires_grad:
                return tensor.detach().cpu().numpy()
            return tensor.cpu().numpy()

        np_inputs = [to_numpy(t) for t in inputs]
        np_outputs = [to_numpy(t) for t in outputs]

        ort_session = onnxruntime.InferenceSession(onnx_io.getvalue())
        # compute onnxruntime output prediction
        session_inputs = ort_session.get_inputs()
        ort_inputs = {session_inputs[i].name: value for i, value in enumerate(np_inputs)}
        ort_outs = ort_session.run(None, ort_inputs)

        for i, expected in enumerate(np_outputs):
            try:
                torch.testing.assert_allclose(expected, ort_outs[i], rtol=1e-03, atol=1e-05)
            except AssertionError as error:
                if tolerate_small_mismatch:
                    assert "(0.00%)" in str(error), str(error)
                else:
                    raise

    def test_nms(self):
        num_boxes = 100
        boxes = torch.rand(num_boxes, 4)
        boxes[:, 2:] += boxes[:, :2]
        scores = torch.randn(num_boxes)

        class Module(torch.nn.Module):
            def forward(self, boxes, scores):
                return ops.nms(boxes, scores, 0.5)

        self.run_model(Module(), [(boxes, scores)])

    def test_batched_nms(self):
        num_boxes = 100
        boxes = torch.rand(num_boxes, 4)
        boxes[:, 2:] += boxes[:, :2]
        scores = torch.randn(num_boxes)
        idxs = torch.randint(0, 5, size=(num_boxes,))

        class Module(torch.nn.Module):
            def forward(self, boxes, scores, idxs):
                return ops.batched_nms(boxes, scores, idxs, 0.5)

        self.run_model(Module(), [(boxes, scores, idxs)])

    def test_clip_boxes_to_image(self):
        boxes = torch.randn(5, 4) * 500
        boxes[:, 2:] += boxes[:, :2]
        size = torch.randn(200, 300)

        size_2 = torch.randn(300, 400)

        class Module(torch.nn.Module):
            def forward(self, boxes, size):
                return ops.boxes.clip_boxes_to_image(boxes, size.shape)

        self.run_model(Module(), [(boxes, size), (boxes, size_2)],
                       input_names=["boxes", "size"],
                       dynamic_axes={"size": [0, 1]})

    def test_roi_align(self):
        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 1, 2)
        self.run_model(model, [(x, single_roi)])

        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 1, -1)
        self.run_model(model, [(x, single_roi)])

    def test_roi_align_aligned(self):
        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 1, 2, aligned=True)
        self.run_model(model, [(x, single_roi)])

        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 0.5, 3, aligned=True)
        self.run_model(model, [(x, single_roi)])

        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 1.8, 2, aligned=True)
        self.run_model(model, [(x, single_roi)])

        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
        model = ops.RoIAlign((2, 2), 2.5, 0, aligned=True)
        self.run_model(model, [(x, single_roi)])

        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
        model = ops.RoIAlign((2, 2), 2.5, -1, aligned=True)
        self.run_model(model, [(x, single_roi)])

    @pytest.mark.skip(reason="Issue in exporting ROIAlign with aligned = True for malformed boxes")
    def test_roi_align_malformed_boxes(self):
        x = torch.randn(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 2, 0.3, 1.5, 1.5]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 1, 1, aligned=True)
        self.run_model(model, [(x, single_roi)])

    def test_roi_pool(self):
        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
        pool_h = 5
        pool_w = 5
        model = ops.RoIPool((pool_h, pool_w), 2)
        self.run_model(model, [(x, rois)])

    def test_resize_images(self):
        class TransformModule(torch.nn.Module):
            def __init__(self_module):
                super(TransformModule, self_module).__init__()
                self_module.transform = self._init_test_generalized_rcnn_transform()

            def forward(self_module, images):
                return self_module.transform.resize(images, None)[0]

        image = torch.rand(3, 10, 20)
        test_image = torch.rand(3, 100, 150)
        self.run_model(TransformModule(), [(image,), (test_image,)],
                       input_names=["input1"], dynamic_axes={"input1": [0, 1, 2]})

    def test_transform_images(self):

        class TransformModule(torch.nn.Module):
            def __init__(self_module):
                super(TransformModule, self_module).__init__()
                self_module.transform = self._init_test_generalized_rcnn_transform()

            def forward(self_module, images):
                return self_module.transform(images)[0].tensors

        images = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
        test_images = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
        self.run_model(TransformModule(), [(images,), (test_images,)])

    def _init_test_generalized_rcnn_transform(self):
        min_size = 100
        max_size = 200
        image_mean = [0.485, 0.456, 0.406]
        image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
        return transform

    def _init_test_rpn(self):
        anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
        aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
        rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
        out_channels = 256
        rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
        rpn_fg_iou_thresh = 0.7
        rpn_bg_iou_thresh = 0.3
        rpn_batch_size_per_image = 256
        rpn_positive_fraction = 0.5
        rpn_pre_nms_top_n = dict(training=2000, testing=1000)
        rpn_post_nms_top_n = dict(training=2000, testing=1000)
        rpn_nms_thresh = 0.7
        rpn_score_thresh = 0.0

        rpn = RegionProposalNetwork(
            rpn_anchor_generator, rpn_head,
            rpn_fg_iou_thresh, rpn_bg_iou_thresh,
            rpn_batch_size_per_image, rpn_positive_fraction,
            rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh,
            score_thresh=rpn_score_thresh)
        return rpn

    def _init_test_roi_heads_faster_rcnn(self):
        out_channels = 256
        num_classes = 91

        box_fg_iou_thresh = 0.5
        box_bg_iou_thresh = 0.5
        box_batch_size_per_image = 512
        box_positive_fraction = 0.25
        bbox_reg_weights = None
        box_score_thresh = 0.05
        box_nms_thresh = 0.5
        box_detections_per_img = 100

        box_roi_pool = ops.MultiScaleRoIAlign(
            featmap_names=['0', '1', '2', '3'],
            output_size=7,
            sampling_ratio=2)

        resolution = box_roi_pool.output_size[0]
        representation_size = 1024
        box_head = TwoMLPHead(
            out_channels * resolution ** 2,
            representation_size)

        box_predictor = FastRCNNPredictor(
            representation_size,
            num_classes)

        roi_heads = RoIHeads(
            box_roi_pool, box_head, box_predictor,
            box_fg_iou_thresh, box_bg_iou_thresh,
            box_batch_size_per_image, box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh, box_nms_thresh, box_detections_per_img)
        return roi_heads

    def get_features(self, images):
        """Return an OrderedDict of five random 256-channel feature maps ('0'..'4') with batch size 2,
        downsampled from the spatial size of ``images`` by 4, 8, 16, 32 and 64."""
        s0, s1 = images.shape[-2:]
        features = [
            ('0', torch.rand(2, 256, s0 // 4, s1 // 4)),
            ('1', torch.rand(2, 256, s0 // 8, s1 // 8)),
            ('2', torch.rand(2, 256, s0 // 16, s1 // 16)),
            ('3', torch.rand(2, 256, s0 // 32, s1 // 32)),
            ('4', torch.rand(2, 256, s0 // 64, s1 // 64)),
        ]
        features = OrderedDict(features)
        return features

    def test_rpn(self):
        set_rng_seed(0)

        class RPNModule(torch.nn.Module):
            def __init__(self_module):
                super(RPNModule, self_module).__init__()
                self_module.rpn = self._init_test_rpn()

            def forward(self_module, images, features):
                images = ImageList(images, [i.shape[-2:] for i in images])
                return self_module.rpn(images, features)

        images = torch.rand(2, 3, 150, 150)
        features = self.get_features(images)
        images2 = torch.rand(2, 3, 80, 80)
        test_features = self.get_features(images2)

        model = RPNModule()
        model.eval()
        model(images, features)

        self.run_model(model, [(images, features), (images2, test_features)], tolerate_small_mismatch=True,
                       input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
                       dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3],
                                     "input3": [0, 1, 2, 3], "input4": [0, 1, 2, 3],
                                     "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]})

    def test_multi_scale_roi_align(self):

        class TransformModule(torch.nn.Module):
            def __init__(self):
                super(TransformModule, self).__init__()
                self.model = ops.MultiScaleRoIAlign(['feat1', 'feat2'], 3, 2)
                self.image_sizes = [(512, 512)]

            def forward(self, input, boxes):
                return self.model(input, boxes, self.image_sizes)

        i = OrderedDict()
        i['feat1'] = torch.rand(1, 5, 64, 64)
        i['feat2'] = torch.rand(1, 5, 16, 16)
        boxes = torch.rand(6, 4) * 256
        boxes[:, 2:] += boxes[:, :2]

        i1 = OrderedDict()
        i1['feat1'] = torch.rand(1, 5, 64, 64)
        i1['feat2'] = torch.rand(1, 5, 16, 16)
        boxes1 = torch.rand(6, 4) * 256
        boxes1[:, 2:] += boxes1[:, :2]

        self.run_model(TransformModule(), [(i, [boxes],), (i1, [boxes1],)])

    def test_roi_heads(self):
        class RoiHeadsModule(torch.nn.Module):
            def __init__(self_module):
                super(RoiHeadsModule, self_module).__init__()
                self_module.transform = self._init_test_generalized_rcnn_transform()
                self_module.rpn = self._init_test_rpn()
                self_module.roi_heads = self._init_test_roi_heads_faster_rcnn()

            def forward(self_module, images, features):
                original_image_sizes = [img.shape[-2:] for img in images]
                images = ImageList(images, [i.shape[-2:] for i in images])
                proposals, _ = self_module.rpn(images, features)
                detections, _ = self_module.roi_heads(features, proposals, images.image_sizes)
                detections = self_module.transform.postprocess(detections,
                                                               images.image_sizes,
                                                               original_image_sizes)
                return detections

        images = torch.rand(2, 3, 100, 100)
        features = self.get_features(images)
        images2 = torch.rand(2, 3, 150, 150)
        test_features = self.get_features(images2)

        model = RoiHeadsModule()
        model.eval()
        model(images, features)

        self.run_model(model, [(images, features), (images2, test_features)], tolerate_small_mismatch=True,
                       input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
                       dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3], "input3": [0, 1, 2, 3],
                                     "input4": [0, 1, 2, 3], "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]})

    def get_image(self, rel_path: str, size: Tuple[int, int]) -> torch.Tensor:
        """Load ``assets/<rel_path>`` (next to this file) as RGB, resize it to ``size`` and return it as a tensor."""
        import os

        from PIL import Image
        from torchvision import transforms

        data_dir = os.path.join(os.path.dirname(__file__), "assets")
        path = os.path.join(data_dir, *rel_path.split("/"))
        image = Image.open(path).convert("RGB").resize(size, Image.BILINEAR)

        return transforms.ToTensor()(image)

    def get_test_images(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        return ([self.get_image("encode_jpeg/grace_hopper_517x606.jpg", (100, 320))],
                [self.get_image("fakedata/logos/rgb_pytorch.png", (250, 380))])

    def test_faster_rcnn(self):
        images, test_images = self.get_test_images()
        dummy_image = [torch.ones(3, 100, 100) * 0.3]
        model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
        model.eval()
        model(images)
        # Test exported model on images of different size, or dummy input
        self.run_model(model, [(images,), (test_images,), (dummy_image,)], input_names=["images_tensors"],
                       output_names=["outputs"],
                       dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
                       tolerate_small_mismatch=True)
        # Test exported model for an image with no detections on other images
        self.run_model(model, [(dummy_image,), (images,)], input_names=["images_tensors"],
                       output_names=["outputs"],
                       dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
                       tolerate_small_mismatch=True)

    # Verify that paste_mask_in_image behaves the same in tracing.
    # This test also compares both paste_masks_in_image and _onnx_paste_masks_in_image
    # (since jit_trace will call _onnx_paste_masks_in_image).
    def test_paste_mask_in_image(self):
        from torchvision.models.detection.roi_heads import paste_masks_in_image

        # disable profiling for the duration of this test only
        with self._jit_profiling_disabled():
            masks = torch.rand(10, 1, 26, 26)
            boxes = torch.rand(10, 4)
            boxes[:, 2:] += torch.rand(10, 2)
            boxes *= 50
            o_im_s = (100, 100)
            out = paste_masks_in_image(masks, boxes, o_im_s)
            jit_trace = torch.jit.trace(paste_masks_in_image,
                                        (masks, boxes,
                                         [torch.tensor(o_im_s[0]),
                                          torch.tensor(o_im_s[1])]))
            out_trace = jit_trace(masks, boxes, [torch.tensor(o_im_s[0]), torch.tensor(o_im_s[1])])

            assert torch.all(out.eq(out_trace))

            masks2 = torch.rand(20, 1, 26, 26)
            boxes2 = torch.rand(20, 4)
            boxes2[:, 2:] += torch.rand(20, 2)
            boxes2 *= 100
            o_im_s2 = (200, 200)
            out2 = paste_masks_in_image(masks2, boxes2, o_im_s2)
            out_trace2 = jit_trace(masks2, boxes2, [torch.tensor(o_im_s2[0]), torch.tensor(o_im_s2[1])])

            assert torch.all(out2.eq(out_trace2))

    def test_mask_rcnn(self):
        images, test_images = self.get_test_images()
        dummy_image = [torch.ones(3, 100, 100) * 0.3]
        model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
        model.eval()
        model(images)
        # Test exported model on images of different size, or dummy input
        self.run_model(model, [(images,), (test_images,), (dummy_image,)],
                       input_names=["images_tensors"],
                       output_names=["boxes", "labels", "scores", "masks"],
                       dynamic_axes={"images_tensors": [0, 1, 2], "boxes": [0, 1], "labels": [0],
                                     "scores": [0], "masks": [0, 1, 2]},
                       tolerate_small_mismatch=True)
        # Test exported model for an image with no detections on other images
        self.run_model(model, [(dummy_image,), (images,)],
                       input_names=["images_tensors"],
                       output_names=["boxes", "labels", "scores", "masks"],
                       dynamic_axes={"images_tensors": [0, 1, 2], "boxes": [0, 1], "labels": [0],
                                     "scores": [0], "masks": [0, 1, 2]},
                       tolerate_small_mismatch=True)

    # Verify that heatmaps_to_keypoints behaves the same in tracing.
    # This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints
    # (tracing exercises the ONNX-specific code path).
    def test_heatmaps_to_keypoints(self):
        from torchvision.models.detection.roi_heads import heatmaps_to_keypoints

        # disable profiling for the duration of this test only
        with self._jit_profiling_disabled():
            maps = torch.rand(10, 1, 26, 26)
            rois = torch.rand(10, 4)
            out = heatmaps_to_keypoints(maps, rois)
            jit_trace = torch.jit.trace(heatmaps_to_keypoints, (maps, rois))
            out_trace = jit_trace(maps, rois)

            assert_equal(out[0], out_trace[0])
            assert_equal(out[1], out_trace[1])

            maps2 = torch.rand(20, 2, 21, 21)
            rois2 = torch.rand(20, 4)
            out2 = heatmaps_to_keypoints(maps2, rois2)
            out_trace2 = jit_trace(maps2, rois2)

            assert_equal(out2[0], out_trace2[0])
            assert_equal(out2[1], out_trace2[1])

    def test_keypoint_rcnn(self):
        images, test_images = self.get_test_images()
        dummy_images = [torch.ones(3, 100, 100) * 0.3]
        model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
        model.eval()
        model(images)
        self.run_model(model, [(images,), (test_images,), (dummy_images,)],
                       input_names=["images_tensors"],
                       output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
                       dynamic_axes={"images_tensors": [0, 1, 2]},
                       tolerate_small_mismatch=True)

        self.run_model(model, [(dummy_images,), (test_images,)],
                       input_names=["images_tensors"],
                       output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
                       dynamic_axes={"images_tensors": [0, 1, 2]},
                       tolerate_small_mismatch=True)

    def test_shufflenet_v2_dynamic_axes(self):
        model = models.shufflenet_v2_x0_5(pretrained=True)
        dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True)
        test_inputs = torch.cat([dummy_input, dummy_input, dummy_input], 0)

        self.run_model(model, [(dummy_input,), (test_inputs,)],
                       input_names=["input_images"],
                       output_names=["output"],
                       dynamic_axes={"input_images": {0: 'batch_size'}, "output": {0: 'batch_size'}},
                       tolerate_small_mismatch=True)
# Allow running this test module directly: `python test_onnx.py`.
if __name__ == '__main__':
    pytest.main([__file__])