"...unconditional_image_generation/train_unconditional.py" did not exist on "0e13d3293c75b4844c8e5832a9b0cceeb9650be3"
test_onnx.py 22.5 KB
Newer Older
1
import io
from collections import OrderedDict
from typing import List, Optional, Tuple

import pytest
import torch
from common_utils import assert_equal, set_rng_seed
from torchvision import models, ops
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from torchvision.models.detection.image_list import ImageList
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.ops import _register_onnx_ops

# In environments without onnxruntime we prefer to
# invoke all tests in the repo and have this one skipped rather than fail.
onnxruntime = pytest.importorskip("onnxruntime")


class TestONNXExporter:
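    # Each test exports a torchvision op or model to ONNX via run_model and
    # then checks the exported graph against eager PyTorch with onnxruntime.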
    @classmethod
    def setup_class(cls):
        torch.manual_seed(123)

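    # Shared driver: export `model` once, tracing it with the first entry of
    # `inputs_list`, then replay every entry through both the eager model and
    # onnxruntime and compare the outputs.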
    def run_model(
        self,
        model,
        inputs_list,
        tolerate_small_mismatch=False,
        do_constant_folding=True,
        dynamic_axes=None,
        output_names=None,
        input_names=None,
        opset_version: Optional[int] = None,
    ):
        if opset_version is None:
            opset_version = _register_onnx_ops.base_onnx_opset_version

        model.eval()

        onnx_io = io.BytesIO()
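        # torch.onnx.export treats a trailing dict in the args tuple as named
        # arguments, so append an empty dict to keep a real dict input positional.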
        if isinstance(inputs_list[0][-1], dict):
            torch_onnx_input = inputs_list[0] + ({},)
        else:
            torch_onnx_input = inputs_list[0]
        # export to onnx with the first input
        torch.onnx.export(
            model,
            torch_onnx_input,
            onnx_io,
            do_constant_folding=do_constant_folding,
            opset_version=opset_version,
            dynamic_axes=dynamic_axes,
            input_names=input_names,
            output_names=output_names,
            verbose=True,
        )
        # validate the exported model with onnx runtime
        for test_inputs in inputs_list:
            with torch.no_grad():
                if isinstance(test_inputs, (torch.Tensor, list)):
                    test_inputs = (test_inputs,)
                test_outputs = model(*test_inputs)
                if isinstance(test_outputs, torch.Tensor):
                    test_outputs = (test_outputs,)
            self.ort_validate(onnx_io, test_inputs, test_outputs, tolerate_small_mismatch)

    def ort_validate(self, onnx_io, inputs, outputs, tolerate_small_mismatch=False):
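        # Flatten nested inputs/outputs to plain tensor lists, convert them to
        # numpy, then run the serialized graph through an onnxruntime session.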
        inputs, _ = torch.jit._flatten(inputs)
        outputs, _ = torch.jit._flatten(outputs)

        def to_numpy(tensor):
            if tensor.requires_grad:
                return tensor.detach().cpu().numpy()
            else:
                return tensor.cpu().numpy()

        inputs = list(map(to_numpy, inputs))
        outputs = list(map(to_numpy, outputs))

        ort_session = onnxruntime.InferenceSession(onnx_io.getvalue())
        # compute onnxruntime output prediction
        ort_inputs = {ort_session.get_inputs()[i].name: inpt for i, inpt in enumerate(inputs)}
        ort_outs = ort_session.run(None, ort_inputs)

        for i in range(len(outputs)):
            try:
                torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-03, atol=1e-05)
            except AssertionError as error:
                if tolerate_small_mismatch:
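                    # assert_allclose reports the share of mismatching elements
                    # in its failure message; only a "(0.00%)" mismatch (a
                    # negligible fraction of elements) is tolerated.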
                    assert "(0.00%)" in str(error), str(error)
                else:
                    raise

    def test_nms(self):
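        # Random but well-formed xyxy boxes: adding the top-left corner to the
        # (w, h)-like second half guarantees x2 >= x1 and y2 >= y1.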
        num_boxes = 100
        boxes = torch.rand(num_boxes, 4)
        boxes[:, 2:] += boxes[:, :2]
        scores = torch.randn(num_boxes)

        class Module(torch.nn.Module):
            def forward(self, boxes, scores):
                return ops.nms(boxes, scores, 0.5)

        self.run_model(Module(), [(boxes, scores)])

    def test_batched_nms(self):
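        # batched_nms suppresses boxes per category: `idxs` assigns each box to
        # one of five classes, and NMS never crosses class boundaries.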
        num_boxes = 100
        boxes = torch.rand(num_boxes, 4)
        boxes[:, 2:] += boxes[:, :2]
        scores = torch.randn(num_boxes)
        idxs = torch.randint(0, 5, size=(num_boxes,))

        class Module(torch.nn.Module):
            def forward(self, boxes, scores, idxs):
                return ops.batched_nms(boxes, scores, idxs, 0.5)

        self.run_model(Module(), [(boxes, scores, idxs)])

    def test_clip_boxes_to_image(self):
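        # The `size` tensors are used only for their shapes; exporting with two
        # different shapes exercises the dynamic size axes declared below.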
        boxes = torch.randn(5, 4) * 500
        boxes[:, 2:] += boxes[:, :2]
        size = torch.randn(200, 300)

        size_2 = torch.randn(300, 400)

        class Module(torch.nn.Module):
            def forward(self, boxes, size):
                return ops.boxes.clip_boxes_to_image(boxes, size.shape)

        self.run_model(
            Module(), [(boxes, size), (boxes, size_2)], input_names=["boxes", "size"], dynamic_axes={"size": [0, 1]}
        )

    def test_roi_align(self):
        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 1, 2)
        self.run_model(model, [(x, single_roi)])

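        # sampling_ratio=-1 lets RoIAlign choose an adaptive number of sampling
        # points per output bin.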
        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 1, -1)
        self.run_model(model, [(x, single_roi)])

    def test_roi_align_aligned(self):
        supported_onnx_version = _register_onnx_ops._onnx_opset_version_16
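        # aligned=True needs opset 16, where ONNX RoiAlign gained the
        # coordinate_transformation_mode attribute.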
        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 1, 2, aligned=True)
        self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)

        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 0.5, 3, aligned=True)
        self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)

        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 1.8, 2, aligned=True)
        self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)

        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
        model = ops.RoIAlign((2, 2), 2.5, 0, aligned=True)
        self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)

        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
        model = ops.RoIAlign((2, 2), 2.5, -1, aligned=True)
        self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)

    def test_roi_align_malformed_boxes(self):
        supported_onnx_version = _register_onnx_ops._onnx_opset_version_16
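        # The ROI is deliberately malformed (x1 > x2): export and eager
        # execution should still agree on degenerate boxes.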
        x = torch.randn(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 2, 0.3, 1.5, 1.5]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 1, 1, aligned=True)
        self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)

    def test_roi_pool(self):
        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
        pool_h = 5
        pool_w = 5
        model = ops.RoIPool((pool_h, pool_w), 2)
        self.run_model(model, [(x, rois)])

    def test_resize_images(self):
        class TransformModule(torch.nn.Module):
            def __init__(self_module):
                super().__init__()
                self_module.transform = self._init_test_generalized_rcnn_transform()

            def forward(self_module, images):
                return self_module.transform.resize(images, None)[0]

        input = torch.rand(3, 10, 20)
        input_test = torch.rand(3, 100, 150)
        self.run_model(
            TransformModule(), [(input,), (input_test,)], input_names=["input1"], dynamic_axes={"input1": [0, 1, 2]}
        )

    def test_transform_images(self):
        class TransformModule(torch.nn.Module):
            def __init__(self_module):
                super().__init__()
                self_module.transform = self._init_test_generalized_rcnn_transform()

            def forward(self_module, images):
                return self_module.transform(images)[0].tensors

        input = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
        input_test = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
        self.run_model(TransformModule(), [(input,), (input_test,)])

    def _init_test_generalized_rcnn_transform(self):
        min_size = 100
        max_size = 200
        image_mean = [0.485, 0.456, 0.406]
        image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
        return transform

    def _init_test_rpn(self):
        anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
        aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
        rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
        out_channels = 256
        rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
        rpn_fg_iou_thresh = 0.7
        rpn_bg_iou_thresh = 0.3
        rpn_batch_size_per_image = 256
        rpn_positive_fraction = 0.5
        rpn_pre_nms_top_n = dict(training=2000, testing=1000)
        rpn_post_nms_top_n = dict(training=2000, testing=1000)
        rpn_nms_thresh = 0.7
        rpn_score_thresh = 0.0

        rpn = RegionProposalNetwork(
            rpn_anchor_generator,
            rpn_head,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_pre_nms_top_n,
            rpn_post_nms_top_n,
            rpn_nms_thresh,
            score_thresh=rpn_score_thresh,
        )
        return rpn

    def _init_test_roi_heads_faster_rcnn(self):
        out_channels = 256
        num_classes = 91

        box_fg_iou_thresh = 0.5
        box_bg_iou_thresh = 0.5
        box_batch_size_per_image = 512
        box_positive_fraction = 0.25
        bbox_reg_weights = None
        box_score_thresh = 0.05
        box_nms_thresh = 0.5
        box_detections_per_img = 100

        box_roi_pool = ops.MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)

        resolution = box_roi_pool.output_size[0]
        representation_size = 1024
        box_head = TwoMLPHead(out_channels * resolution**2, representation_size)

        representation_size = 1024
        box_predictor = FastRCNNPredictor(representation_size, num_classes)

        roi_heads = RoIHeads(
            box_roi_pool,
            box_head,
            box_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
        )
        return roi_heads

    def get_features(self, images):
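        # Simulate an FPN-style feature pyramid: one random feature map per
        # level, halving the spatial resolution from stride 4 down to stride 64.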
        s0, s1 = images.shape[-2:]
        features = [
            ("0", torch.rand(2, 256, s0 // 4, s1 // 4)),
            ("1", torch.rand(2, 256, s0 // 8, s1 // 8)),
            ("2", torch.rand(2, 256, s0 // 16, s1 // 16)),
            ("3", torch.rand(2, 256, s0 // 32, s1 // 32)),
            ("4", torch.rand(2, 256, s0 // 64, s1 // 64)),
        ]
        features = OrderedDict(features)
        return features

    def test_rpn(self):
        set_rng_seed(0)

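        # Wrap the RPN so tracing sees plain tensors: ImageList is rebuilt
        # inside forward because torch.onnx.export cannot take it as an input.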
        class RPNModule(torch.nn.Module):
            def __init__(self_module):
                super().__init__()
                self_module.rpn = self._init_test_rpn()

            def forward(self_module, images, features):
                images = ImageList(images, [i.shape[-2:] for i in images])
                return self_module.rpn(images, features)

        images = torch.rand(2, 3, 150, 150)
        features = self.get_features(images)
        images2 = torch.rand(2, 3, 80, 80)
        test_features = self.get_features(images2)

        model = RPNModule()
        model.eval()
        model(images, features)

        self.run_model(
            model,
            [(images, features), (images2, test_features)],
            tolerate_small_mismatch=True,
            input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
            dynamic_axes={
                "input1": [0, 1, 2, 3],
                "input2": [0, 1, 2, 3],
                "input3": [0, 1, 2, 3],
                "input4": [0, 1, 2, 3],
                "input5": [0, 1, 2, 3],
                "input6": [0, 1, 2, 3],
            },
        )

    def test_multi_scale_roi_align(self):
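        # MultiScaleRoIAlign assigns each box to a pyramid level based on its
        # area (the FPN heuristic), so feature maps at two resolutions are fed in.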
        class TransformModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.model = ops.MultiScaleRoIAlign(["feat1", "feat2"], 3, 2)
                self.image_sizes = [(512, 512)]

            def forward(self, input, boxes):
                return self.model(input, boxes, self.image_sizes)

        i = OrderedDict()
        i["feat1"] = torch.rand(1, 5, 64, 64)
        i["feat2"] = torch.rand(1, 5, 16, 16)
        boxes = torch.rand(6, 4) * 256
        boxes[:, 2:] += boxes[:, :2]

        i1 = OrderedDict()
        i1["feat1"] = torch.rand(1, 5, 64, 64)
        i1["feat2"] = torch.rand(1, 5, 16, 16)
        boxes1 = torch.rand(6, 4) * 256
        boxes1[:, 2:] += boxes1[:, :2]

        self.run_model(
            TransformModule(),
            [
                (
                    i,
                    [boxes],
                ),
                (
                    i1,
                    [boxes1],
                ),
            ],
        )

    def test_roi_heads(self):
        class RoiHeadsModule(torch.nn.Module):
            def __init__(self_module):
                super().__init__()
                self_module.transform = self._init_test_generalized_rcnn_transform()
                self_module.rpn = self._init_test_rpn()
                self_module.roi_heads = self._init_test_roi_heads_faster_rcnn()

            def forward(self_module, images, features):
                original_image_sizes = [img.shape[-2:] for img in images]
                images = ImageList(images, [i.shape[-2:] for i in images])
                proposals, _ = self_module.rpn(images, features)
                detections, _ = self_module.roi_heads(features, proposals, images.image_sizes)
                detections = self_module.transform.postprocess(detections, images.image_sizes, original_image_sizes)
                return detections

        images = torch.rand(2, 3, 100, 100)
        features = self.get_features(images)
        images2 = torch.rand(2, 3, 150, 150)
        test_features = self.get_features(images2)

        model = RoiHeadsModule()
        model.eval()
        model(images, features)

        self.run_model(
            model,
            [(images, features), (images2, test_features)],
            tolerate_small_mismatch=True,
            input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
            dynamic_axes={
                "input1": [0, 1, 2, 3],
                "input2": [0, 1, 2, 3],
                "input3": [0, 1, 2, 3],
                "input4": [0, 1, 2, 3],
                "input5": [0, 1, 2, 3],
                "input6": [0, 1, 2, 3],
            },
        )

    def get_image(self, rel_path: str, size: Tuple[int, int]) -> torch.Tensor:
        import os

        import torchvision.transforms._pil_constants as _pil_constants
        from PIL import Image
        from torchvision.transforms import functional as F

        data_dir = os.path.join(os.path.dirname(__file__), "assets")
        path = os.path.join(data_dir, *rel_path.split("/"))
        image = Image.open(path).convert("RGB").resize(size, _pil_constants.BILINEAR)

        return F.convert_image_dtype(F.pil_to_tensor(image))

    def get_test_images(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
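        # Two real images of different sizes, each wrapped as a single-image
        # batch, so detection models get exercised with dynamic input shapes.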
        return (
            [self.get_image("encode_jpeg/grace_hopper_517x606.jpg", (100, 320))],
            [self.get_image("fakedata/logos/rgb_pytorch.png", (250, 380))],
        )

    def test_faster_rcnn(self):
        images, test_images = self.get_test_images()
        dummy_image = [torch.ones(3, 100, 100) * 0.3]
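        # Reduced min_size/max_size keep the transform resolution, and thus the
        # test runtime, small.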
        model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(
            weights=models.detection.faster_rcnn.FasterRCNN_ResNet50_FPN_Weights.DEFAULT, min_size=200, max_size=300
        )
        model.eval()
        model(images)
        # Test exported model on images of different size, or dummy input
        self.run_model(
            model,
            [(images,), (test_images,), (dummy_image,)],
            input_names=["images_tensors"],
            output_names=["outputs"],
            dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
            tolerate_small_mismatch=True,
        )
        # Test exported model for an image with no detections on other images
        self.run_model(
            model,
            [(dummy_image,), (images,)],
            input_names=["images_tensors"],
            output_names=["outputs"],
            dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
            tolerate_small_mismatch=True,
        )

    # Verify that paste_masks_in_image behaves the same in tracing.
    # This test also compares both paste_masks_in_image and _onnx_paste_masks_in_image
    # (since jit_trace will call _onnx_paste_masks_in_image).
    def test_paste_mask_in_image(self):
        masks = torch.rand(10, 1, 26, 26)
        boxes = torch.rand(10, 4)
        boxes[:, 2:] += torch.rand(10, 2)
        boxes *= 50
        o_im_s = (100, 100)
        from torchvision.models.detection.roi_heads import paste_masks_in_image

        out = paste_masks_in_image(masks, boxes, o_im_s)
        jit_trace = torch.jit.trace(
            paste_masks_in_image, (masks, boxes, [torch.tensor(o_im_s[0]), torch.tensor(o_im_s[1])])
        )
        out_trace = jit_trace(masks, boxes, [torch.tensor(o_im_s[0]), torch.tensor(o_im_s[1])])

        assert torch.all(out.eq(out_trace))

        masks2 = torch.rand(20, 1, 26, 26)
        boxes2 = torch.rand(20, 4)
        boxes2[:, 2:] += torch.rand(20, 2)
        boxes2 *= 100
        o_im_s2 = (200, 200)

        out2 = paste_masks_in_image(masks2, boxes2, o_im_s2)
        out_trace2 = jit_trace(masks2, boxes2, [torch.tensor(o_im_s2[0]), torch.tensor(o_im_s2[1])])

        assert torch.all(out2.eq(out_trace2))

    def test_mask_rcnn(self):
        images, test_images = self.get_test_images()
        dummy_image = [torch.ones(3, 100, 100) * 0.3]
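        # Same setup as test_faster_rcnn, with named outputs and dynamic axes
        # covering the extra mask head.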
        model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(
            weights=models.detection.mask_rcnn.MaskRCNN_ResNet50_FPN_Weights.DEFAULT, min_size=200, max_size=300
        )
        model.eval()
        model(images)
        # Test exported model on images of different size, or dummy input
        self.run_model(
            model,
            [(images,), (test_images,), (dummy_image,)],
            input_names=["images_tensors"],
            output_names=["boxes", "labels", "scores", "masks"],
            dynamic_axes={
                "images_tensors": [0, 1, 2],
                "boxes": [0, 1],
                "labels": [0],
                "scores": [0],
                "masks": [0, 1, 2],
            },
            tolerate_small_mismatch=True,
        )
        # Test exported model for an image with no detections on other images
        self.run_model(
            model,
            [(dummy_image,), (images,)],
            input_names=["images_tensors"],
            output_names=["boxes", "labels", "scores", "masks"],
            dynamic_axes={
                "images_tensors": [0, 1, 2],
                "boxes": [0, 1],
                "labels": [0],
                "scores": [0],
                "masks": [0, 1, 2],
            },
            tolerate_small_mismatch=True,
        )

    # Verify that heatmaps_to_keypoints behaves the same in tracing.
    # This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints
    # (since jit_trace will call _onnx_heatmaps_to_keypoints).
    def test_heatmaps_to_keypoints(self):
        maps = torch.rand(10, 1, 26, 26)
        rois = torch.rand(10, 4)
        from torchvision.models.detection.roi_heads import heatmaps_to_keypoints

        out = heatmaps_to_keypoints(maps, rois)
        jit_trace = torch.jit.trace(heatmaps_to_keypoints, (maps, rois))
        out_trace = jit_trace(maps, rois)

        assert_equal(out[0], out_trace[0])
        assert_equal(out[1], out_trace[1])

        maps2 = torch.rand(20, 2, 21, 21)
        rois2 = torch.rand(20, 4)

        out2 = heatmaps_to_keypoints(maps2, rois2)
        out_trace2 = jit_trace(maps2, rois2)

        assert_equal(out2[0], out_trace2[0])
        assert_equal(out2[1], out_trace2[1])

    def test_keypoint_rcnn(self):
        images, test_images = self.get_test_images()
        dummy_images = [torch.ones(3, 100, 100) * 0.3]
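        # Keypoint R-CNN at reduced resolution; the exported outputs get the
        # generic names outputs1..outputs4 below.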
        model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(
            weights=models.detection.keypoint_rcnn.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT, min_size=200, max_size=300
        )
        model.eval()
        model(images)
        self.run_model(
            model,
            [(images,), (test_images,), (dummy_images,)],
            input_names=["images_tensors"],
            output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
            dynamic_axes={"images_tensors": [0, 1, 2]},
            tolerate_small_mismatch=True,
        )

        self.run_model(
            model,
            [(dummy_images,), (test_images,)],
            input_names=["images_tensors"],
            output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
            dynamic_axes={"images_tensors": [0, 1, 2]},
            tolerate_small_mismatch=True,
        )

    def test_shufflenet_v2_dynamic_axes(self):
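        # A classification model with a symbolic dynamic batch axis: here the
        # axis is named ("batch_size") rather than listed by index.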
        model = models.shufflenet_v2_x0_5(weights=models.ShuffleNet_V2_X0_5_Weights.DEFAULT)
        dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True)
        test_inputs = torch.cat([dummy_input, dummy_input, dummy_input], 0)

        self.run_model(
            model,
            [(dummy_input,), (test_inputs,)],
            input_names=["input_images"],
            output_names=["output"],
            dynamic_axes={"input_images": {0: "batch_size"}, "output": {0: "batch_size"}},
            tolerate_small_mismatch=True,
        )


if __name__ == "__main__":
    pytest.main([__file__])