"vscode:/vscode.git/clone" did not exist on "fc2c3a3d8e47fc8c981970fb72f33c4283509fb1"
test_onnx.py 21.4 KB
Newer Older
1
from common_utils import set_rng_seed, assert_equal
2
import io
3
import pytest
4
5
import torch
from torchvision import ops
6
from torchvision import models
7
from torchvision.models.detection.image_list import ImageList
8
from torchvision.models.detection.transform import GeneralizedRCNNTransform
9
from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
10
11
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
12

13
from collections import OrderedDict
14
from typing import List, Tuple
15

16
from torchvision.ops._register_onnx_ops import _onnx_opset_version
17

18
19
20
21
# onnxruntime is treated as optional here: in environments where it is not
# installed we still want to invoke every test in the repo, so this module is
# skipped (via pytest.importorskip) rather than failing at import time.
onnxruntime = pytest.importorskip("onnxruntime")

22

23
class TestONNXExporter:
24
    @classmethod
25
    def setup_class(cls):
26
27
        torch.manual_seed(123)

28
29
    def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True, dynamic_axes=None,
                  output_names=None, input_names=None):
        """Export ``model`` to ONNX and validate the export on every entry of ``inputs_list``.

        The model is put in eval mode and exported into an in-memory buffer
        using the first entry of ``inputs_list``. Each entry is then run through
        the PyTorch model (without gradients) and passed, together with the
        PyTorch outputs, to ``ort_validate``.

        Args:
            model: module to export.
            inputs_list: sequence of model inputs; entry 0 drives the export.
                An entry that is a bare tensor or list is wrapped in a 1-tuple
                before the model is called.
            tolerate_small_mismatch: forwarded to ``ort_validate``.
            do_constant_folding, dynamic_axes, output_names, input_names:
                forwarded to ``torch.onnx.export``.
        """
        model.eval()

        buffer = io.BytesIO()
        export_args = inputs_list[0]
        if isinstance(export_args[-1], dict):
            # NOTE(review): presumably needed because torch.onnx.export reads a
            # trailing dict as keyword arguments; the extra empty dict keeps
            # the real one positional -- confirm against the torch.onnx docs.
            export_args = export_args + ({},)

        # export to onnx with the first input
        torch.onnx.export(model, export_args, buffer,
                          do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version,
                          dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names)

        # validate the exported model with onnx runtime
        for sample in inputs_list:
            if isinstance(sample, (torch.Tensor, list)):
                sample = (sample,)
            with torch.no_grad():
                expected = model(*sample)
            if isinstance(expected, torch.Tensor):
                expected = (expected,)
            self.ort_validate(buffer, sample, expected, tolerate_small_mismatch)
51

52
    def ort_validate(self, onnx_io, inputs, outputs, tolerate_small_mismatch=False):
        """Run the exported model with onnxruntime and compare it with ``outputs``.

        Args:
            onnx_io: in-memory buffer holding the serialized ONNX model.
            inputs: model inputs; flattened and converted to numpy arrays.
            outputs: reference PyTorch outputs; flattened and converted likewise.
            tolerate_small_mismatch: when true, a comparison failure is accepted
                as long as its message contains ``"(0.00%)"``; otherwise the
                ``AssertionError`` is re-raised.
        """

        def as_array(tensor):
            # only tensors that track gradients need detaching first
            source = tensor.detach() if tensor.requires_grad else tensor
            return source.cpu().numpy()

        flat_inputs, _ = torch.jit._flatten(inputs)
        flat_outputs, _ = torch.jit._flatten(outputs)
        feed_arrays = [as_array(t) for t in flat_inputs]
        expected = [as_array(t) for t in flat_outputs]

        session = onnxruntime.InferenceSession(onnx_io.getvalue())
        # compute onnxruntime output prediction; the flattened inputs are
        # matched to the session's declared inputs by position
        session_inputs = session.get_inputs()
        feed = {session_inputs[pos].name: array for pos, array in enumerate(feed_arrays)}
        actual = session.run(None, feed)

        for idx, reference in enumerate(expected):
            try:
                torch.testing.assert_allclose(reference, actual[idx], rtol=1e-03, atol=1e-05)
            except AssertionError as error:
                if not tolerate_small_mismatch:
                    raise
                assert "(0.00%)" in str(error), str(error)
79
80

    def test_nms(self):
        """Export a module wrapping ``ops.nms`` and validate it on random boxes."""
        box_count = 100
        iou_threshold = 0.5

        boxes = torch.rand(box_count, 4)
        # add the first two columns onto the last two, in place
        boxes[:, 2:] += boxes[:, :2]
        scores = torch.randn(box_count)

        class Module(torch.nn.Module):
            def forward(self, boxes, scores):
                return ops.nms(boxes, scores, iou_threshold)

        self.run_model(Module(), [(boxes, scores)])
91

92
93
94
95
96
97
98
99
100
101
102
103
104
    def test_batched_nms(self):
        """Export a module wrapping ``ops.batched_nms`` and validate it on random boxes."""
        box_count = 100
        iou_threshold = 0.5

        boxes = torch.rand(box_count, 4)
        # add the first two columns onto the last two, in place
        boxes[:, 2:] += boxes[:, :2]
        scores = torch.randn(box_count)
        # one integer in [0, 5) per box
        idxs = torch.randint(0, 5, size=(box_count,))

        class Module(torch.nn.Module):
            def forward(self, boxes, scores, idxs):
                return ops.batched_nms(boxes, scores, idxs, iou_threshold)

        self.run_model(Module(), [(boxes, scores, idxs)])

105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
    def test_clip_boxes_to_image(self):
        """Export ``clip_boxes_to_image`` with a dynamic image size.

        The ``size`` tensors only contribute their shape; the second one has a
        different shape from the one used for export.
        """
        boxes = torch.randn(5, 4) * 500
        boxes[:, 2:] += boxes[:, :2]
        export_canvas = torch.randn(200, 300)
        check_canvas = torch.randn(300, 400)

        class Module(torch.nn.Module):
            def forward(self, boxes, size):
                return ops.boxes.clip_boxes_to_image(boxes, size.shape)

        samples = [(boxes, export_canvas), (boxes, check_canvas)]
        self.run_model(Module(), samples,
                       input_names=["boxes", "size"],
                       dynamic_axes={"size": [0, 1]})

120
    def test_roi_align(self):
        """Export ``ops.RoIAlign`` for two values of its third argument (2 and -1)."""
        for sampling_ratio in (2, -1):
            x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
            single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
            model = ops.RoIAlign((5, 5), 1, sampling_ratio)
            self.run_model(model, [(x, single_roi)])

131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
    def test_roi_align_aligned(self):
        """Export ``ops.RoIAlign(..., aligned=True)`` for several configurations."""
        shared_roi = [0, 0.2, 0.3, 4.5, 3.5]
        # each row: (roi, first, second, third positional argument of RoIAlign)
        cases = (
            ([0, 1.5, 1.5, 3, 3], (5, 5), 1, 2),
            (shared_roi, (5, 5), 0.5, 3),
            (shared_roi, (5, 5), 1.8, 2),
            (shared_roi, (2, 2), 2.5, 0),
            (shared_roi, (2, 2), 2.5, -1),
        )
        for roi, output_size, spatial_scale, sampling_ratio in cases:
            x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
            single_roi = torch.tensor([roi], dtype=torch.float32)
            model = ops.RoIAlign(output_size, spatial_scale, sampling_ratio, aligned=True)
            self.run_model(model, [(x, single_roi)])

157
    @pytest.mark.skip(reason="Issue in exporting ROIAlign with aligned = True for malformed boxes")
    def test_roi_align_malformed_boxes(self):
        """Export aligned ``ops.RoIAlign`` on a single fixed RoI (currently skipped)."""
        feature_map = torch.randn(1, 1, 10, 10, dtype=torch.float32)
        # in this RoI the second value (2) is larger than the fourth (1.5)
        roi_row = [0, 2, 0.3, 1.5, 1.5]
        single_roi = torch.tensor([roi_row], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 1, 1, aligned=True)
        self.run_model(model, [(feature_map, single_roi)])

164
    def test_roi_pool(self):
        """Export ``ops.RoIPool`` with a 5x5 output and validate it on one RoI."""
        output_size = (5, 5)
        feature_map = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
        model = ops.RoIPool(output_size, 2)
        self.run_model(model, [(feature_map, rois)])

172
173
174
175
176
177
178
179
180
181
182
183
    def test_resize_images(self):
        """Export the transform's ``resize`` step with all three input axes dynamic."""

        class TransformModule(torch.nn.Module):
            # the first argument is named ``module`` so the enclosing test's
            # ``self`` stays reachable inside the methods
            def __init__(module):
                super().__init__()
                module.transform = self._init_test_generalized_rcnn_transform()

            def forward(module, images):
                resized = module.transform.resize(images, None)
                return resized[0]

        export_image = torch.rand(3, 10, 20)
        check_image = torch.rand(3, 100, 150)
        self.run_model(TransformModule(), [(export_image,), (check_image,)],
                       input_names=["input1"], dynamic_axes={"input1": [0, 1, 2]})
185

186
187
188
189
190
    def test_transform_images(self):
        """Export the full transform applied to a pair of images of different sizes."""

        class TransformModule(torch.nn.Module):
            # the first argument is named ``module`` so the enclosing test's
            # ``self`` stays reachable inside the methods
            def __init__(module):
                super().__init__()
                module.transform = self._init_test_generalized_rcnn_transform()

            def forward(module, images):
                image_list = module.transform(images)[0]
                return image_list.tensors

        # each sample is a tuple of two images, passed as one model argument
        export_images = (torch.rand(3, 100, 200), torch.rand(3, 200, 200))
        check_images = (torch.rand(3, 100, 200), torch.rand(3, 200, 200))
        self.run_model(TransformModule(), [(export_images,), (check_images,)])
199

200
    def _init_test_generalized_rcnn_transform(self):
        """Build the ``GeneralizedRCNNTransform`` shared by the tests in this class."""
        image_mean = [0.485, 0.456, 0.406]
        image_std = [0.229, 0.224, 0.225]
        min_size, max_size = 100, 200
        return GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)

    def _init_test_rpn(self):
        """Build the ``RegionProposalNetwork`` used by the RPN and RoI-heads tests."""
        # anchors: one size per entry, the same three aspect ratios for each
        anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
        aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
        anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)

        out_channels = 256
        anchors_per_location = anchor_generator.num_anchors_per_location()[0]
        head = RPNHead(out_channels, anchors_per_location)

        fg_iou_thresh = 0.7
        bg_iou_thresh = 0.3
        batch_size_per_image = 256
        positive_fraction = 0.5
        pre_nms_top_n = {"training": 2000, "testing": 1000}
        post_nms_top_n = {"training": 2000, "testing": 1000}
        nms_thresh = 0.7
        score_thresh = 0.0

        return RegionProposalNetwork(
            anchor_generator, head,
            fg_iou_thresh, bg_iou_thresh,
            batch_size_per_image, positive_fraction,
            pre_nms_top_n, post_nms_top_n, nms_thresh,
            score_thresh=score_thresh)

231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
    def _init_test_roi_heads_faster_rcnn(self):
        """Build the ``RoIHeads`` (pooler, two-layer MLP head, predictor) used by ``test_roi_heads``."""
        out_channels = 256
        num_classes = 91
        representation_size = 1024

        box_roi_pool = ops.MultiScaleRoIAlign(
            featmap_names=['0', '1', '2', '3'],
            output_size=7,
            sampling_ratio=2)

        # the head's input size is channels * resolution^2, with the resolution
        # read back from the pooler
        resolution = box_roi_pool.output_size[0]
        box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
        box_predictor = FastRCNNPredictor(representation_size, num_classes)

        # remaining positional settings, in the order RoIHeads receives them
        fg_iou_thresh = 0.5
        bg_iou_thresh = 0.5
        batch_size_per_image = 512
        positive_fraction = 0.25
        bbox_reg_weights = None
        score_thresh = 0.05
        nms_thresh = 0.5
        detections_per_img = 100

        return RoIHeads(
            box_roi_pool, box_head, box_predictor,
            fg_iou_thresh, bg_iou_thresh,
            batch_size_per_image, positive_fraction,
            bbox_reg_weights,
            score_thresh, nms_thresh, detections_per_img)

    def get_features(self, images):
        """Return five random feature maps sized from the last two dims of ``images``.

        The result is an ``OrderedDict`` keyed ``'0'`` .. ``'4'``. Every map has
        shape ``(2, 256, h // d, w // d)`` with ``d`` = 4, 8, 16, 32, 64, where
        ``h, w = images.shape[-2:]``.
        """
        height, width = images.shape[-2:]
        divisors = (4, 8, 16, 32, 64)
        return OrderedDict(
            (str(level), torch.rand(2, 256, height // div, width // div))
            for level, div in enumerate(divisors)
        )

280
    def test_rpn(self):
        """Export an RPN wrapper and validate it on inputs of two different sizes."""
        set_rng_seed(0)

        class RPNModule(torch.nn.Module):
            # the first argument is named ``module`` so the enclosing test's
            # ``self`` stays reachable inside the methods
            def __init__(module):
                super().__init__()
                module.rpn = self._init_test_rpn()

            def forward(module, images, features):
                image_sizes = [img.shape[-2:] for img in images]
                return module.rpn(ImageList(images, image_sizes), features)

        images = torch.rand(2, 3, 150, 150)
        features = self.get_features(images)
        images2 = torch.rand(2, 3, 80, 80)
        test_features = self.get_features(images2)

        model = RPNModule()
        model.eval()
        # run the model once in eager mode before exporting it
        model(images, features)

        # "input1" .. "input6", each with its first four axes marked dynamic
        names = ["input{}".format(k) for k in range(1, 7)]
        axes = {name: [0, 1, 2, 3] for name in names}
        self.run_model(model, [(images, features), (images2, test_features)], tolerate_small_mismatch=True,
                       input_names=names, dynamic_axes=axes)
306

307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
    def test_multi_scale_roi_align(self):
        """Export ``ops.MultiScaleRoIAlign`` over two feature maps and validate on two samples."""

        class TransformModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.model = ops.MultiScaleRoIAlign(['feat1', 'feat2'], 3, 2)
                self.image_sizes = [(512, 512)]

            def forward(self, features, boxes):
                return self.model(features, boxes, self.image_sizes)

        def make_sample():
            # random draws happen in a fixed order: feat1, feat2, then boxes
            feats = OrderedDict()
            feats['feat1'] = torch.rand(1, 5, 64, 64)
            feats['feat2'] = torch.rand(1, 5, 16, 16)
            boxes = torch.rand(6, 4) * 256
            boxes[:, 2:] += boxes[:, :2]
            return feats, [boxes]

        export_sample = make_sample()
        check_sample = make_sample()
        self.run_model(TransformModule(), [export_sample, check_sample])

332
333
    def test_roi_heads(self):
        """Export an RPN + RoI heads + postprocess pipeline and validate it on two input sizes."""

        class RoiHeadsModule(torch.nn.Module):
            # the first argument is named ``module`` so the enclosing test's
            # ``self`` stays reachable inside the methods
            def __init__(module):
                super().__init__()
                module.transform = self._init_test_generalized_rcnn_transform()
                module.rpn = self._init_test_rpn()
                module.roi_heads = self._init_test_roi_heads_faster_rcnn()

            def forward(module, images, features):
                original_image_sizes = [img.shape[-2:] for img in images]
                image_list = ImageList(images, [img.shape[-2:] for img in images])
                proposals, _ = module.rpn(image_list, features)
                detections, _ = module.roi_heads(features, proposals, image_list.image_sizes)
                return module.transform.postprocess(detections,
                                                    image_list.image_sizes,
                                                    original_image_sizes)

        images = torch.rand(2, 3, 100, 100)
        features = self.get_features(images)
        images2 = torch.rand(2, 3, 150, 150)
        test_features = self.get_features(images2)

        model = RoiHeadsModule()
        model.eval()
        # run the model once in eager mode before exporting it
        model(images, features)

        # "input1" .. "input6", each with its first four axes marked dynamic
        names = ["input{}".format(k) for k in range(1, 7)]
        axes = {name: [0, 1, 2, 3] for name in names}
        self.run_model(model, [(images, features), (images2, test_features)], tolerate_small_mismatch=True,
                       input_names=names, dynamic_axes=axes)

364
365
    def get_image(self, rel_path: str, size: Tuple[int, int]) -> torch.Tensor:
        """Load an image from the ``assets`` directory next to this file as a tensor.

        Args:
            rel_path: "/"-separated path of the image relative to ``assets``.
            size: target size handed to PIL's ``resize``
                (NOTE(review): PIL documents this as (width, height) -- confirm).

        Returns:
            The RGB image, resized with bilinear resampling and converted by
            ``torchvision.transforms.ToTensor``.
        """
        import os

        from PIL import Image
        from torchvision import transforms

        data_dir = os.path.join(os.path.dirname(__file__), "assets")
        path = os.path.join(data_dir, *rel_path.split("/"))
        # Open the file in a context manager so its handle is closed as soon as
        # the converted/resized copy exists, instead of lingering until garbage
        # collection.
        with Image.open(path) as source:
            image = source.convert("RGB").resize(size, Image.BILINEAR)

        return transforms.ToTensor()(image)
374

375
376
377
    def get_test_images(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Return two single-image lists loaded through ``get_image``.

        The first list holds the JPEG asset requested at size (100, 320), the
        second the PNG asset requested at size (250, 380).
        """
        first = self.get_image("encode_jpeg/grace_hopper_517x606.jpg", (100, 320))
        second = self.get_image("fakedata/logos/rgb_pytorch.png", (250, 380))
        return [first], [second]
378
379
380

    def test_faster_rcnn(self):
        """Export the pretrained Faster R-CNN model and validate it in two scenarios."""
        images, test_images = self.get_test_images()
        dummy_image = [torch.ones(3, 100, 100) * 0.3]
        model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
        model.eval()
        # run the model once in eager mode before exporting it
        model(images)

        scenarios = (
            # Test exported model on images of different size, or dummy input
            [(images,), (test_images,), (dummy_image,)],
            # Test exported model for an image with no detections on other images
            [(dummy_image,), (images,)],
        )
        for inputs_list in scenarios:
            # the keyword arguments are rebuilt on every iteration so each
            # export receives fresh list/dict objects
            self.run_model(model, inputs_list, input_names=["images_tensors"],
                           output_names=["outputs"],
                           dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
                           tolerate_small_mismatch=True)
395

396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
    # Verify that paste_masks_in_image behaves the same in tracing.
    # This test also compares both paste_masks_in_image and _onnx_paste_masks_in_image
    # (since jit_trace will call _onnx_paste_masks_in_image).
    def test_paste_mask_in_image(self):
        """Check that a traced ``paste_masks_in_image`` matches eager execution.

        The function is traced on a 10-mask sample with a (100, 100) image size,
        and the same trace is then reused on a 20-mask sample with (200, 200).
        """
        from torchvision.models.detection.roi_heads import paste_masks_in_image

        def size_as_tensors(hw):
            # the traced function is given the image size as a list of 0-d tensors
            return [torch.tensor(hw[0]), torch.tensor(hw[1])]

        masks = torch.rand(10, 1, 26, 26)
        boxes = torch.rand(10, 4)
        boxes[:, 2:] += torch.rand(10, 2)
        boxes *= 50
        image_size = (100, 100)
        eager = paste_masks_in_image(masks, boxes, image_size)
        traced_fn = torch.jit.trace(paste_masks_in_image,
                                    (masks, boxes, size_as_tensors(image_size)))
        traced = traced_fn(masks, boxes, size_as_tensors(image_size))

        assert torch.all(eager.eq(traced))

        masks2 = torch.rand(20, 1, 26, 26)
        boxes2 = torch.rand(20, 4)
        boxes2[:, 2:] += torch.rand(20, 2)
        boxes2 *= 100
        image_size2 = (200, 200)
        eager2 = paste_masks_in_image(masks2, boxes2, image_size2)
        traced2 = traced_fn(masks2, boxes2, size_as_tensors(image_size2))

        assert torch.all(eager2.eq(traced2))

    def test_mask_rcnn(self):
        """Export the pretrained Mask R-CNN model and validate it in two scenarios."""
        images, test_images = self.get_test_images()
        dummy_image = [torch.ones(3, 100, 100) * 0.3]
        model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
        model.eval()
        # run the model once in eager mode before exporting it
        model(images)

        scenarios = (
            # Test exported model on images of different size, or dummy input
            [(images,), (test_images,), (dummy_image,)],
            # Test exported model for an image with no detections on other images
            [(dummy_image,), (images,)],
        )
        for inputs_list in scenarios:
            # the keyword arguments are rebuilt on every iteration so each
            # export receives fresh list/dict objects
            self.run_model(model, inputs_list,
                           input_names=["images_tensors"],
                           output_names=["boxes", "labels", "scores", "masks"],
                           dynamic_axes={"images_tensors": [0, 1, 2], "boxes": [0, 1], "labels": [0],
                                         "scores": [0], "masks": [0, 1, 2]},
                           tolerate_small_mismatch=True)
446

447
448
449
450
451
452
453
454
455
456
457
    # Verify that heatmaps_to_keypoints behaves the same in tracing.
    # This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints
    # (since jit_trace will call _heatmaps_to_keypoints).
    def test_heatmaps_to_keypoints(self):
        """Check that a traced ``heatmaps_to_keypoints`` matches eager execution.

        The function is traced on a (10, 1, 26, 26) sample and the same trace is
        then reused on a (20, 2, 21, 21) sample; both returned values are
        compared each time.
        """
        from torchvision.models.detection.roi_heads import heatmaps_to_keypoints

        maps = torch.rand(10, 1, 26, 26)
        rois = torch.rand(10, 4)
        eager = heatmaps_to_keypoints(maps, rois)
        traced_fn = torch.jit.trace(heatmaps_to_keypoints, (maps, rois))
        traced = traced_fn(maps, rois)

        for position in (0, 1):
            assert_equal(eager[position], traced[position])

        maps2 = torch.rand(20, 2, 21, 21)
        rois2 = torch.rand(20, 4)
        eager2 = heatmaps_to_keypoints(maps2, rois2)
        traced2 = traced_fn(maps2, rois2)

        for position in (0, 1):
            assert_equal(eager2[position], traced2[position])
469

470
    def test_keypoint_rcnn(self):
        """Export the pretrained Keypoint R-CNN model and validate it in two scenarios."""
        images, test_images = self.get_test_images()
        dummy_images = [torch.ones(3, 100, 100) * 0.3]
        model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
        model.eval()
        # run the model once in eager mode before exporting it
        model(images)

        scenarios = (
            # export with the first real image, then check the second image and the dummy
            [(images,), (test_images,), (dummy_images,)],
            # export with the dummy image, then check the second real image
            [(dummy_images,), (test_images,)],
        )
        for inputs_list in scenarios:
            # the keyword arguments are rebuilt on every iteration so each
            # export receives fresh list/dict objects
            self.run_model(model, inputs_list,
                           input_names=["images_tensors"],
                           output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
                           dynamic_axes={"images_tensors": [0, 1, 2]},
                           tolerate_small_mismatch=True)
487

488
489
490
491
492
493
494
495
496
497
498
    def test_shufflenet_v2_dynamic_axes(self):
        """Export ShuffleNet V2 x0.5 with a dynamic batch axis and validate a batch of three."""
        model = models.shufflenet_v2_x0_5(pretrained=True)
        single = torch.randn(1, 3, 224, 224, requires_grad=True)
        # the check batch is the export sample repeated three times along dim 0
        batch_of_three = torch.cat([single] * 3, 0)

        batch_axis = {0: 'batch_size'}
        self.run_model(model, [(single,), (batch_of_three,)],
                       input_names=["input_images"],
                       output_names=["output"],
                       dynamic_axes={"input_images": dict(batch_axis), "output": dict(batch_axis)},
                       tolerate_small_mismatch=True)

499
500

# Allow running this test module directly as a script: hand the file to pytest.
if __name__ == '__main__':
    pytest.main([__file__])