import io
from collections import OrderedDict
from typing import List, Tuple

import pytest
import torch
from common_utils import assert_equal, set_rng_seed
from torchvision import models, ops
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from torchvision.models.detection.image_list import ImageList
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.ops._register_onnx_ops import _onnx_opset_version

# In environments without onnxruntime we prefer to
# run all tests in the repo and have this one skipped rather than failing.
onnxruntime = pytest.importorskip("onnxruntime")


class TestONNXExporter:
    @classmethod
    def setup_class(cls):
        torch.manual_seed(123)

    def run_model(
        self,
        model,
        inputs_list,
        tolerate_small_mismatch=False,
        do_constant_folding=True,
        dynamic_axes=None,
        output_names=None,
        input_names=None,
    ):
        model.eval()

        onnx_io = io.BytesIO()
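        # torch.onnx.export treats a trailing dict in the args tuple as keyword arguments,
        # so when the last model input is a dict we append an empty dict to keep it positional.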
        if isinstance(inputs_list[0][-1], dict):
            torch_onnx_input = inputs_list[0] + ({},)
        else:
            torch_onnx_input = inputs_list[0]
        # export to onnx with the first input
        torch.onnx.export(
            model,
            torch_onnx_input,
            onnx_io,
            do_constant_folding=do_constant_folding,
            opset_version=_onnx_opset_version,
            dynamic_axes=dynamic_axes,
            input_names=input_names,
            output_names=output_names,
        )
        # validate the exported model with onnx runtime
        for test_inputs in inputs_list:
            with torch.no_grad():
                if isinstance(test_inputs, (torch.Tensor, list)):
                    test_inputs = (test_inputs,)
                test_outputs = model(*test_inputs)
                if isinstance(test_outputs, torch.Tensor):
                    test_outputs = (test_outputs,)
            self.ort_validate(onnx_io, test_inputs, test_outputs, tolerate_small_mismatch)

    def ort_validate(self, onnx_io, inputs, outputs, tolerate_small_mismatch=False):
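        # Run the exported graph in ONNX Runtime and compare against eager PyTorch outputs.
        # torch.jit._flatten below flattens nested containers of tensors so the two sides line up one-to-one.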

        inputs, _ = torch.jit._flatten(inputs)
        outputs, _ = torch.jit._flatten(outputs)

        def to_numpy(tensor):
            if tensor.requires_grad:
                return tensor.detach().cpu().numpy()
            else:
                return tensor.cpu().numpy()

        inputs = list(map(to_numpy, inputs))
        outputs = list(map(to_numpy, outputs))

        ort_session = onnxruntime.InferenceSession(onnx_io.getvalue())
        # compute onnxruntime output prediction
        ort_inputs = {ort_session.get_inputs()[i].name: inpt for i, inpt in enumerate(inputs)}
        ort_outs = ort_session.run(None, ort_inputs)

        for i in range(len(outputs)):
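            # Compare each ONNX Runtime output with the eager PyTorch output; when
            # tolerate_small_mismatch is set, accept errors that affect 0.00% of elements.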
            try:
                torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-03, atol=1e-05)
            except AssertionError as error:
                if tolerate_small_mismatch:
                    assert "(0.00%)" in str(error), str(error)
                else:
                    raise

    def test_nms(self):
        num_boxes = 100
        boxes = torch.rand(num_boxes, 4)
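        # Offset the max corner by the min corner so every box satisfies x1 <= x2 and y1 <= y2.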
        boxes[:, 2:] += boxes[:, :2]
        scores = torch.randn(num_boxes)

        class Module(torch.nn.Module):
            def forward(self, boxes, scores):
                return ops.nms(boxes, scores, 0.5)

        self.run_model(Module(), [(boxes, scores)])

    def test_batched_nms(self):
        num_boxes = 100
        boxes = torch.rand(num_boxes, 4)
        boxes[:, 2:] += boxes[:, :2]
        scores = torch.randn(num_boxes)
        idxs = torch.randint(0, 5, size=(num_boxes,))

        class Module(torch.nn.Module):
            def forward(self, boxes, scores, idxs):
                return ops.batched_nms(boxes, scores, idxs, 0.5)

        self.run_model(Module(), [(boxes, scores, idxs)])

    def test_clip_boxes_to_image(self):
        boxes = torch.randn(5, 4) * 500
        boxes[:, 2:] += boxes[:, :2]
        size = torch.randn(200, 300)

        size_2 = torch.randn(300, 400)
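        # The second size has different spatial dims; dynamic_axes below marks both dims of `size` as dynamic.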

        class Module(torch.nn.Module):
            def forward(self, boxes, size):
                return ops.boxes.clip_boxes_to_image(boxes, size.shape)

        self.run_model(
            Module(), [(boxes, size), (boxes, size_2)], input_names=["boxes", "size"], dynamic_axes={"size": [0, 1]}
        )

    def test_roi_align(self):
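        # RoIAlign with a 5x5 output, spatial_scale=1 and a fixed sampling_ratio of 2.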
        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 1, 2)
        self.run_model(model, [(x, single_roi)])

        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
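        # sampling_ratio=-1 lets RoIAlign choose an adaptive number of sampling points per bin.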
        model = ops.RoIAlign((5, 5), 1, -1)
        self.run_model(model, [(x, single_roi)])

    @pytest.mark.skip(reason="ROIAlign with aligned=True is not supported in ONNX, but will be supported in opset 16.")
    def test_roi_align_aligned(self):
        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 1, 2, aligned=True)
        self.run_model(model, [(x, single_roi)])

        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 0.5, 3, aligned=True)
        self.run_model(model, [(x, single_roi)])

        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 1.8, 2, aligned=True)
        self.run_model(model, [(x, single_roi)])

        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
        model = ops.RoIAlign((2, 2), 2.5, 0, aligned=True)
        self.run_model(model, [(x, single_roi)])

        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
        model = ops.RoIAlign((2, 2), 2.5, -1, aligned=True)
        self.run_model(model, [(x, single_roi)])

    @pytest.mark.skip(reason="Issue in exporting ROIAlign with aligned = True for malformed boxes")
    def test_roi_align_malformed_boxes(self):
        x = torch.randn(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 2, 0.3, 1.5, 1.5]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 1, 1, aligned=True)
        self.run_model(model, [(x, single_roi)])

    def test_roi_pool(self):
        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
        pool_h = 5
        pool_w = 5
        model = ops.RoIPool((pool_h, pool_w), 2)
        self.run_model(model, [(x, rois)])

    def test_resize_images(self):
        class TransformModule(torch.nn.Module):
            def __init__(self_module):
                super().__init__()
                self_module.transform = self._init_test_generalized_rcnn_transform()

            def forward(self_module, images):
                return self_module.transform.resize(images, None)[0]

        input = torch.rand(3, 10, 20)
        input_test = torch.rand(3, 100, 150)
        self.run_model(
            TransformModule(), [(input,), (input_test,)], input_names=["input1"], dynamic_axes={"input1": [0, 1, 2]}
        )

    def test_transform_images(self):
        class TransformModule(torch.nn.Module):
            def __init__(self_module):
                super().__init__()
                self_module.transform = self._init_test_generalized_rcnn_transform()

            def forward(self_module, images):
                return self_module.transform(images)[0].tensors

        input = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
        input_test = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
        self.run_model(TransformModule(), [(input,), (input_test,)])

    def _init_test_generalized_rcnn_transform(self):
        min_size = 100
        max_size = 200
        image_mean = [0.485, 0.456, 0.406]
        image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
        return transform

    def _init_test_rpn(self):
        anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
        aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
        rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
        out_channels = 256
        rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
        rpn_fg_iou_thresh = 0.7
        rpn_bg_iou_thresh = 0.3
        rpn_batch_size_per_image = 256
        rpn_positive_fraction = 0.5
        rpn_pre_nms_top_n = dict(training=2000, testing=1000)
        rpn_post_nms_top_n = dict(training=2000, testing=1000)
        rpn_nms_thresh = 0.7
        rpn_score_thresh = 0.0

        rpn = RegionProposalNetwork(
            rpn_anchor_generator,
            rpn_head,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_pre_nms_top_n,
            rpn_post_nms_top_n,
            rpn_nms_thresh,
            score_thresh=rpn_score_thresh,
        )
        return rpn

    def _init_test_roi_heads_faster_rcnn(self):
        out_channels = 256
        num_classes = 91

        box_fg_iou_thresh = 0.5
        box_bg_iou_thresh = 0.5
        box_batch_size_per_image = 512
        box_positive_fraction = 0.25
        bbox_reg_weights = None
        box_score_thresh = 0.05
        box_nms_thresh = 0.5
        box_detections_per_img = 100

        box_roi_pool = ops.MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)

        resolution = box_roi_pool.output_size[0]
        representation_size = 1024
        box_head = TwoMLPHead(out_channels * resolution**2, representation_size)

        representation_size = 1024
        box_predictor = FastRCNNPredictor(representation_size, num_classes)

        roi_heads = RoIHeads(
            box_roi_pool,
            box_head,
            box_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
        )
        return roi_heads

    def get_features(self, images):
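        # Fake multi-scale (FPN-style) feature maps for a batch of two images,
        # with 256 channels at strides 4, 8, 16, 32 and 64.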
        s0, s1 = images.shape[-2:]
        features = [
            ("0", torch.rand(2, 256, s0 // 4, s1 // 4)),
            ("1", torch.rand(2, 256, s0 // 8, s1 // 8)),
            ("2", torch.rand(2, 256, s0 // 16, s1 // 16)),
            ("3", torch.rand(2, 256, s0 // 32, s1 // 32)),
            ("4", torch.rand(2, 256, s0 // 64, s1 // 64)),
        ]
        features = OrderedDict(features)
        return features

    def test_rpn(self):
        set_rng_seed(0)

        class RPNModule(torch.nn.Module):
            def __init__(self_module):
                super().__init__()
                self_module.rpn = self._init_test_rpn()

            def forward(self_module, images, features):
                images = ImageList(images, [i.shape[-2:] for i in images])
                return self_module.rpn(images, features)

        images = torch.rand(2, 3, 150, 150)
        features = self.get_features(images)
        images2 = torch.rand(2, 3, 80, 80)
        test_features = self.get_features(images2)

        model = RPNModule()
        model.eval()
        model(images, features)
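        # The export below declares one input per flattened tensor: the image batch plus the
        # five feature maps, each with fully dynamic shapes.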

        self.run_model(
            model,
            [(images, features), (images2, test_features)],
            tolerate_small_mismatch=True,
            input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
            dynamic_axes={
                "input1": [0, 1, 2, 3],
                "input2": [0, 1, 2, 3],
                "input3": [0, 1, 2, 3],
                "input4": [0, 1, 2, 3],
                "input5": [0, 1, 2, 3],
                "input6": [0, 1, 2, 3],
            },
        )

    def test_multi_scale_roi_align(self):
        class TransformModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.model = ops.MultiScaleRoIAlign(["feat1", "feat2"], 3, 2)
                self.image_sizes = [(512, 512)]

            def forward(self, input, boxes):
                return self.model(input, boxes, self.image_sizes)

        i = OrderedDict()
        i["feat1"] = torch.rand(1, 5, 64, 64)
        i["feat2"] = torch.rand(1, 5, 16, 16)
        boxes = torch.rand(6, 4) * 256
        boxes[:, 2:] += boxes[:, :2]

        i1 = OrderedDict()
        i1["feat1"] = torch.rand(1, 5, 64, 64)
        i1["feat2"] = torch.rand(1, 5, 16, 16)
        boxes1 = torch.rand(6, 4) * 256
        boxes1[:, 2:] += boxes1[:, :2]

        self.run_model(
            TransformModule(),
            [
                (
                    i,
                    [boxes],
                ),
                (
                    i1,
                    [boxes1],
                ),
            ],
        )

    def test_roi_heads(self):
        class RoiHeadsModule(torch.nn.Module):
            def __init__(self_module):
                super().__init__()
                self_module.transform = self._init_test_generalized_rcnn_transform()
                self_module.rpn = self._init_test_rpn()
                self_module.roi_heads = self._init_test_roi_heads_faster_rcnn()

            def forward(self_module, images, features):
                original_image_sizes = [img.shape[-2:] for img in images]
                images = ImageList(images, [i.shape[-2:] for i in images])
                proposals, _ = self_module.rpn(images, features)
                detections, _ = self_module.roi_heads(features, proposals, images.image_sizes)
                detections = self_module.transform.postprocess(detections, images.image_sizes, original_image_sizes)
                return detections

        images = torch.rand(2, 3, 100, 100)
        features = self.get_features(images)
        images2 = torch.rand(2, 3, 150, 150)
        test_features = self.get_features(images2)

        model = RoiHeadsModule()
        model.eval()
        model(images, features)

        self.run_model(
            model,
            [(images, features), (images2, test_features)],
            tolerate_small_mismatch=True,
            input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
            dynamic_axes={
                "input1": [0, 1, 2, 3],
                "input2": [0, 1, 2, 3],
                "input3": [0, 1, 2, 3],
                "input4": [0, 1, 2, 3],
                "input5": [0, 1, 2, 3],
                "input6": [0, 1, 2, 3],
            },
        )

    def get_image(self, rel_path: str, size: Tuple[int, int]) -> torch.Tensor:
        import os

        import torchvision.transforms._pil_constants as _pil_constants
        from PIL import Image
        from torchvision.transforms import functional as F

        data_dir = os.path.join(os.path.dirname(__file__), "assets")
        path = os.path.join(data_dir, *rel_path.split("/"))
        image = Image.open(path).convert("RGB").resize(size, _pil_constants.BILINEAR)

        return F.convert_image_dtype(F.pil_to_tensor(image))

    def get_test_images(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        return (
            [self.get_image("encode_jpeg/grace_hopper_517x606.jpg", (100, 320))],
            [self.get_image("fakedata/logos/rgb_pytorch.png", (250, 380))],
        )

    def test_faster_rcnn(self):
        images, test_images = self.get_test_images()
        dummy_image = [torch.ones(3, 100, 100) * 0.3]
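        # A constant-valued image which produces no detections (see the second export below).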
        model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(
            weights=models.detection.faster_rcnn.FasterRCNN_ResNet50_FPN_Weights.DEFAULT, min_size=200, max_size=300
        )
        model.eval()
        model(images)
        # Test the exported model on images of different sizes, as well as on a dummy input
        self.run_model(
            model,
            [(images,), (test_images,), (dummy_image,)],
            input_names=["images_tensors"],
            output_names=["outputs"],
            dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
            tolerate_small_mismatch=True,
        )
        # Export with an image that produces no detections and test the exported model on other images
        self.run_model(
            model,
            [(dummy_image,), (images,)],
            input_names=["images_tensors"],
            output_names=["outputs"],
            dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
            tolerate_small_mismatch=True,
        )

    # Verify that paste_masks_in_image behaves the same in tracing.
    # This test also compares both paste_masks_in_image and _onnx_paste_masks_in_image
    # (since jit_trace will call _onnx_paste_masks_in_image).
    def test_paste_mask_in_image(self):
        masks = torch.rand(10, 1, 26, 26)
        boxes = torch.rand(10, 4)
        boxes[:, 2:] += torch.rand(10, 2)
        boxes *= 50
        o_im_s = (100, 100)
        from torchvision.models.detection.roi_heads import paste_masks_in_image

        out = paste_masks_in_image(masks, boxes, o_im_s)
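        # Tracing dispatches to the ONNX-friendly _onnx_paste_masks_in_image path (see the comment above);
        # the image size is passed as tensors so it is captured by the trace.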
        jit_trace = torch.jit.trace(
            paste_masks_in_image, (masks, boxes, [torch.tensor(o_im_s[0]), torch.tensor(o_im_s[1])])
        )
        out_trace = jit_trace(masks, boxes, [torch.tensor(o_im_s[0]), torch.tensor(o_im_s[1])])

        assert torch.all(out.eq(out_trace))

        masks2 = torch.rand(20, 1, 26, 26)
        boxes2 = torch.rand(20, 4)
        boxes2[:, 2:] += torch.rand(20, 2)
        boxes2 *= 100
        o_im_s2 = (200, 200)

        out2 = paste_masks_in_image(masks2, boxes2, o_im_s2)
        out_trace2 = jit_trace(masks2, boxes2, [torch.tensor(o_im_s2[0]), torch.tensor(o_im_s2[1])])

        assert torch.all(out2.eq(out_trace2))

    def test_mask_rcnn(self):
        images, test_images = self.get_test_images()
        dummy_image = [torch.ones(3, 100, 100) * 0.3]
        model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(
            weights=models.detection.mask_rcnn.MaskRCNN_ResNet50_FPN_Weights.DEFAULT, min_size=200, max_size=300
        )
        model.eval()
        model(images)
        # Test the exported model on images of different sizes, as well as on a dummy input
        self.run_model(
            model,
            [(images,), (test_images,), (dummy_image,)],
            input_names=["images_tensors"],
            output_names=["boxes", "labels", "scores", "masks"],
            dynamic_axes={
                "images_tensors": [0, 1, 2],
                "boxes": [0, 1],
                "labels": [0],
                "scores": [0],
                "masks": [0, 1, 2],
            },
            tolerate_small_mismatch=True,
        )
        # Export with an image that produces no detections and test the exported model on other images
        self.run_model(
            model,
            [(dummy_image,), (images,)],
            input_names=["images_tensors"],
            output_names=["boxes", "labels", "scores", "masks"],
            dynamic_axes={
                "images_tensors": [0, 1, 2],
                "boxes": [0, 1],
                "labels": [0],
                "scores": [0],
                "masks": [0, 1, 2],
            },
            tolerate_small_mismatch=True,
        )

    # Verify that heatmaps_to_keypoints behaves the same in tracing.
    # This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints
    # (since jit_trace will call _onnx_heatmaps_to_keypoints).
    def test_heatmaps_to_keypoints(self):
        maps = torch.rand(10, 1, 26, 26)
        rois = torch.rand(10, 4)
        from torchvision.models.detection.roi_heads import heatmaps_to_keypoints

        out = heatmaps_to_keypoints(maps, rois)
        jit_trace = torch.jit.trace(heatmaps_to_keypoints, (maps, rois))
        out_trace = jit_trace(maps, rois)

        assert_equal(out[0], out_trace[0])
        assert_equal(out[1], out_trace[1])

        maps2 = torch.rand(20, 2, 21, 21)
        rois2 = torch.rand(20, 4)

        out2 = heatmaps_to_keypoints(maps2, rois2)
        out_trace2 = jit_trace(maps2, rois2)

        assert_equal(out2[0], out_trace2[0])
        assert_equal(out2[1], out_trace2[1])

    def test_keypoint_rcnn(self):
        images, test_images = self.get_test_images()
        dummy_images = [torch.ones(3, 100, 100) * 0.3]
        model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(
            weights=models.detection.keypoint_rcnn.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT, min_size=200, max_size=300
        )
        model.eval()
        model(images)
        self.run_model(
            model,
            [(images,), (test_images,), (dummy_images,)],
            input_names=["images_tensors"],
            output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
            dynamic_axes={"images_tensors": [0, 1, 2]},
            tolerate_small_mismatch=True,
        )

        self.run_model(
            model,
            [(dummy_images,), (test_images,)],
            input_names=["images_tensors"],
            output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
            dynamic_axes={"images_tensors": [0, 1, 2]},
            tolerate_small_mismatch=True,
        )

    def test_shufflenet_v2_dynamic_axes(self):
        model = models.shufflenet_v2_x0_5(weights=models.ShuffleNet_V2_X0_5_Weights.DEFAULT)
        dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True)
        test_inputs = torch.cat([dummy_input, dummy_input, dummy_input], 0)
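        # A batch of three identical images exercises the dynamic batch dimension declared below.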

        self.run_model(
            model,
            [(dummy_input,), (test_inputs,)],
            input_names=["input_images"],
            output_names=["output"],
            dynamic_axes={"input_images": {0: "batch_size"}, "output": {0: "batch_size"}},
            tolerate_small_mismatch=True,
        )


if __name__ == "__main__":
    pytest.main([__file__])