"vscode:/vscode.git/clone" did not exist on "a1f2ded58a502f0d65c86ff6c86f417689a5c8d4"
test_ops.py 75.5 KB
Newer Older
1
import math
2
import os
3
from abc import ABC, abstractmethod
4
from functools import lru_cache
5
from itertools import product
6
from typing import Callable, List, Tuple
7

8
import numpy as np
9
import pytest
10
import torch
11
import torch.fx
12
import torch.nn.functional as F
13
from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps
14
from PIL import Image
15
from torch import nn, Tensor
16
from torch.autograd import gradcheck
17
from torch.nn.modules.utils import _pair
18
from torchvision import models, ops
19
20
21
from torchvision.models.feature_extraction import get_graph_node_names


22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# Context manager that switches torch's deterministic-algorithms mode on entry
# and puts the previously active settings back on exit.
class DeterministicGuard:
    """Temporarily override ``torch.use_deterministic_algorithms``.

    Args:
        deterministic: value handed to ``torch.use_deterministic_algorithms``
            while the guard is active.
        warn_only: keyword-only; forwarded as ``warn_only`` for the guarded region.
    """

    def __init__(self, deterministic, *, warn_only=False):
        self.warn_only = warn_only
        self.deterministic = deterministic

    def __enter__(self):
        # Snapshot both flags before changing anything so __exit__ can undo it.
        previous_mode = torch.are_deterministic_algorithms_enabled()
        previous_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
        self.deterministic_restore = previous_mode
        self.warn_only_restore = previous_warn_only
        torch.use_deterministic_algorithms(self.deterministic, warn_only=self.warn_only)

    def __exit__(self, exception_type, exception_value, traceback):
        # Nothing is returned, so an exception raised in the guarded block propagates.
        torch.use_deterministic_algorithms(
            self.deterministic_restore,
            warn_only=self.warn_only_restore,
        )


38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
class RoIOpTesterModuleWrapper(nn.Module):
    """Module holding ``obj`` as ``self.layer`` and calling it with two inputs.

    ``n_inputs`` records how many arguments ``forward`` takes.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 2
        self.layer = obj

    def forward(self, a, b):
        # The wrapped layer's result is intentionally discarded; forward returns None.
        self.layer(a, b)


class MultiScaleRoIAlignModuleWrapper(nn.Module):
    """Module holding ``obj`` as ``self.layer`` and calling it with three inputs.

    ``n_inputs`` records how many arguments ``forward`` takes.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 3
        self.layer = obj

    def forward(self, a, b, c):
        # The wrapped layer's result is intentionally discarded; forward returns None.
        self.layer(a, b, c)


class DeformConvModuleWrapper(nn.Module):
    """Module holding ``obj`` as ``self.layer`` and calling it with three inputs.

    ``n_inputs`` records how many arguments ``forward`` takes.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 3
        self.layer = obj

    def forward(self, a, b, c):
        # The wrapped layer's result is intentionally discarded; forward returns None.
        self.layer(a, b, c)


class StochasticDepthWrapper(nn.Module):
    """Module holding ``obj`` as ``self.layer`` and calling it with a single input.

    ``n_inputs`` records how many arguments ``forward`` takes.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 1
        self.layer = obj

    def forward(self, a):
        # The wrapped layer's result is intentionally discarded; forward returns None.
        self.layer(a)
76
77


78
79
80
81
82
83
84
85
86
87
class DropBlockWrapper(nn.Module):
    """Module holding ``obj`` as ``self.layer`` and calling it with a single input.

    ``n_inputs`` records how many arguments ``forward`` takes.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 1
        self.layer = obj

    def forward(self, a):
        # The wrapped layer's result is intentionally discarded; forward returns None.
        self.layer(a)


88
89
90
91
92
93
94
95
96
class PoolWrapper(nn.Module):
    """Thin module that forwards ``(imgs, boxes)`` to the wrapped ``pool`` module."""

    def __init__(self, pool: nn.Module):
        super().__init__()
        self.pool = pool

    def forward(self, imgs: Tensor, boxes: List[Tensor]) -> Tensor:
        # Boxes are passed through as a list of tensors, unchanged.
        pooled = self.pool(imgs, boxes)
        return pooled


97
98
class RoIOpTester(ABC):
    """Shared test-suite for RoI pooling/align style operators.

    Subclasses provide four hooks:

    * ``fn`` - run the operator under test,
    * ``make_obj`` - build the operator module (optionally wrapped for graph inspection),
    * ``get_script_fn`` - build a TorchScript-compiled callable of the functional op,
    * ``expected_fn`` - pure-Python reference implementation used as ground truth.
    """

    # dtype used by test_backward on non-MPS devices
    dtype = torch.float64
    # dtype and gradcheck tolerance used by test_backward on MPS
    mps_dtype = torch.float32
    mps_backward_atol = 2e-2

    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize(
        "x_dtype",
        (
            torch.float16,
            torch.float32,
            torch.float64,
        ),
        ids=str,
    )
    def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, deterministic=False, **kwargs):
        """Compare ``self.fn`` against ``self.expected_fn`` on a fixed set of RoIs.

        ``rois_dtype`` defaults to ``x_dtype``; extra ``kwargs`` are forwarded to both
        ``fn`` and ``expected_fn``.
        """
        if device == "mps" and x_dtype is torch.float64:
            pytest.skip("MPS does not support float64")

        rois_dtype = x_dtype if rois_dtype is None else rois_dtype

        # Half precision gets a looser tolerance (slightly looser still on MPS).
        tol = 1e-5
        if x_dtype is torch.half:
            tol = 5e-3 if device == "mps" else 4e-3

        pool_size = 5
        # n_channels % (pool_size ** 2) == 0 required for PS operations.
        n_channels = 2 * (pool_size**2)
        x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device)
        if not contiguous:
            # Swapping the last two dims yields a non-contiguous view.
            x = x.permute(0, 1, 3, 2)
        rois = torch.tensor(
            [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]],  # format is (xyxy)
            dtype=rois_dtype,
            device=device,
        )

        pool_h, pool_w = pool_size, pool_size
        with DeterministicGuard(deterministic):
            y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs)
        # the following should be true whether we're running an autocast test or not.
        assert y.dtype == x.dtype
        gt_y = self.expected_fn(
            x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=x_dtype, **kwargs
        )

        torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_is_leaf_node(self, device):
        """Check the node-name lists of the wrapped op: one node per input plus one more."""
        op_obj = self.make_obj(wrap=True).to(device=device)
        graph_node_names = get_graph_node_names(op_obj)

        assert len(graph_node_names) == 2
        assert len(graph_node_names[0]) == len(graph_node_names[1])
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.float):
        """The symbolically traced module must produce the same output as the eager module."""
        op_obj = self.make_obj().to(device=device)
        graph_module = torch.fx.symbolic_trace(op_obj)
        pool_size = 5
        n_channels = 2 * (pool_size**2)
        x = torch.rand(2, n_channels, 5, 5, dtype=x_dtype, device=device)
        rois = torch.tensor(
            [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]],  # format is (xyxy)
            dtype=rois_dtype,
            device=device,
        )
        output_gt = op_obj(x, rois)
        assert output_gt.dtype == x.dtype
        output_fx = graph_module(x, rois)
        assert output_fx.dtype == x.dtype
        tol = 1e-5
        torch.testing.assert_close(output_gt, output_fx, rtol=tol, atol=tol)

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
    @pytest.mark.parametrize("contiguous", (True, False))
    def test_backward(self, seed, device, contiguous, deterministic=False):
        """Run ``gradcheck`` on the eager op and on its scripted counterpart."""
        atol = self.mps_backward_atol if device == "mps" else 1e-05
        dtype = self.mps_dtype if device == "mps" else self.dtype

        torch.random.manual_seed(seed)
        pool_size = 2
        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=dtype, device=device, requires_grad=True)
        if not contiguous:
            x = x.permute(0, 1, 3, 2)
        rois = torch.tensor(
            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=dtype, device=device  # format is (xyxy)
        )

        def func(z):
            return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)

        script_func = self.get_script_fn(rois, pool_size)

        # Only the eager check runs under the deterministic guard.
        with DeterministicGuard(deterministic):
            gradcheck(func, (x,), atol=atol)

        gradcheck(script_func, (x,), atol=atol)

    @needs_mps
    def test_mps_error_inputs(self):
        """Backward with float16 inputs on MPS is expected to raise a RuntimeError."""
        pool_size = 2
        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=torch.float16, device="mps", requires_grad=True)
        rois = torch.tensor(
            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=torch.float16, device="mps"  # format is (xyxy)
        )

        def func(z):
            return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)

        with pytest.raises(
            RuntimeError, match="MPS does not support (?:ps_)?roi_(?:align|pool)? backward with float16 inputs."
        ):
            gradcheck(func, (x,))

    @needs_cuda
    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
    def test_autocast(self, x_dtype, rois_dtype):
        """Re-run the forward test on CUDA inside an autocast region."""
        with torch.cuda.amp.autocast():
            self.test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)

    def _helper_boxes_shape(self, func):
        """Check that ``func`` rejects wrongly shaped boxes with an AssertionError.

        Both the single-tensor form and the list-of-tensors form are exercised
        against ``func`` itself.
        """
        a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)

        # Boxes given as one Tensor: here it has 4 columns instead of the 5 of Tensor[N, 5].
        boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype)
        with pytest.raises(AssertionError):
            func(a, boxes, output_size=(2, 2))

        # Boxes given as a list: here the entry has 3 columns instead of the 4 of Tensor[N, 4].
        boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype)
        with pytest.raises(AssertionError):
            func(a, [boxes], output_size=(2, 2))

    def _helper_jit_boxes_list(self, model):
        """Script ``model`` and call it with boxes passed as a list of tensors."""
        x = torch.rand(2, 1, 10, 10)
        # .t() turns the 4x5 literal into a 5x4 tensor, so each list entry holds 5 boxes.
        roi = torch.tensor([[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], dtype=torch.float).t()
        rois = [roi, roi]
        scripted = torch.jit.script(model)
        y = scripted(x, rois)
        assert y.shape == (10, 1, 3, 3)

    @abstractmethod
    def fn(*args, **kwargs):
        """Run the operator under test."""
        pass

    @abstractmethod
    def make_obj(*args, **kwargs):
        """Build the operator module (wrapped when ``wrap=True`` is requested)."""
        pass

    @abstractmethod
    def get_script_fn(*args, **kwargs):
        """Return a TorchScript-compiled callable taking only the input tensor."""
        pass

    @abstractmethod
    def expected_fn(*args, **kwargs):
        """Reference implementation producing the expected output."""
        pass
262

263

264
class TestRoiPool(RoIOpTester):
    """RoIOpTester specialisation for ``ops.RoIPool`` / ``ops.roi_pool``."""

    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        # sampling_ratio and kwargs are accepted for interface parity and not used here.
        return ops.RoIPool((pool_h, pool_w), spatial_scale)(x, rois)

    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False):
        obj = ops.RoIPool((pool_h, pool_w), spatial_scale)
        return RoIOpTesterModuleWrapper(obj) if wrap else obj

    def get_script_fn(self, rois, pool_size):
        scripted = torch.jit.script(ops.roi_pool)
        return lambda x: scripted(x, rois, pool_size)

    def expected_fn(
        self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64
    ):
        """Reference RoI max-pooling computed bin by bin.

        Each row of ``rois`` is read as ``(batch_idx, x1, y1, x2, y2)``. Coordinates are
        scaled by ``spatial_scale`` and rounded, and the end coordinates are inclusive.
        Bins that end up empty keep the initial value 0. ``sampling_ratio`` is unused.
        """
        if device is None:
            device = torch.device("cpu")

        n_channels = x.size(1)
        y = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device)

        def get_slice(k, block):
            # Extent of the k-th bin along one axis, rounded outward.
            return slice(int(np.floor(k * block)), int(np.ceil((k + 1) * block)))

        for roi_idx, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (int(round(coord.item() * spatial_scale)) for coord in roi[1:])
            roi_x = x[batch_idx, :, i_begin : i_end + 1, j_begin : j_end + 1]

            roi_h, roi_w = roi_x.shape[-2:]
            bin_h = roi_h / pool_h
            bin_w = roi_w / pool_w

            for i in range(0, pool_h):
                for j in range(0, pool_w):
                    bin_x = roi_x[:, get_slice(i, bin_h), get_slice(j, bin_w)]
                    if bin_x.numel() > 0:
                        y[roi_idx, :, i, j] = bin_x.reshape(n_channels, -1).max(dim=1)[0]
        return y

    def test_boxes_shape(self):
        self._helper_boxes_shape(ops.roi_pool)

    def test_jit_boxes_list(self):
        model = PoolWrapper(ops.RoIPool(output_size=[3, 3], spatial_scale=1.0))
        self._helper_jit_boxes_list(model)

311

312
class TestPSRoIPool(RoIOpTester):
    """RoIOpTester specialisation for ``ops.PSRoIPool`` / ``ops.ps_roi_pool``."""

    mps_backward_atol = 5e-2

    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        # Forward spatial_scale instead of a hard-coded 1, matching the other testers.
        # sampling_ratio and kwargs are accepted for interface parity and not used here.
        return ops.PSRoIPool((pool_h, pool_w), spatial_scale)(x, rois)

    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False):
        obj = ops.PSRoIPool((pool_h, pool_w), spatial_scale)
        return RoIOpTesterModuleWrapper(obj) if wrap else obj

    def get_script_fn(self, rois, pool_size):
        scripted = torch.jit.script(ops.ps_roi_pool)
        return lambda x: scripted(x, rois, pool_size)

    def expected_fn(
        self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64
    ):
        """Reference position-sensitive RoI average pooling.

        Each row of ``rois`` is read as ``(batch_idx, x1, y1, x2, y2)``. Output channel
        ``c_out`` at bin ``(i, j)`` averages input channel
        ``c_out * pool_h * pool_w + pool_w * i + j`` over that bin. Empty bins keep
        the initial value 0. ``sampling_ratio`` is unused.
        """
        if device is None:
            device = torch.device("cpu")
        n_input_channels = x.size(1)
        assert n_input_channels % (pool_h * pool_w) == 0, "input channels must be divisible by ph * pw"
        # Exact thanks to the divisibility check above.
        n_output_channels = n_input_channels // (pool_h * pool_w)
        y = torch.zeros(rois.size(0), n_output_channels, pool_h, pool_w, dtype=dtype, device=device)

        def get_slice(k, block):
            # Extent of the k-th bin along one axis, rounded outward.
            return slice(int(np.floor(k * block)), int(np.ceil((k + 1) * block)))

        for roi_idx, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (int(round(coord.item() * spatial_scale)) for coord in roi[1:])
            roi_x = x[batch_idx, :, i_begin : i_end + 1, j_begin : j_end + 1]

            # RoI extent is clamped to at least 1 before deriving the bin size.
            roi_height = max(i_end - i_begin, 1)
            roi_width = max(j_end - j_begin, 1)
            bin_h, bin_w = roi_height / float(pool_h), roi_width / float(pool_w)

            for i in range(0, pool_h):
                for j in range(0, pool_w):
                    bin_x = roi_x[:, get_slice(i, bin_h), get_slice(j, bin_w)]
                    if bin_x.numel() > 0:
                        area = bin_x.size(-2) * bin_x.size(-1)
                        for c_out in range(0, n_output_channels):
                            c_in = c_out * (pool_h * pool_w) + pool_w * i + j
                            t = torch.sum(bin_x[c_in, :, :])
                            y[roi_idx, c_out, i, j] = t / area
        return y

    def test_boxes_shape(self):
        self._helper_boxes_shape(ops.ps_roi_pool)

362

363
364
def bilinear_interpolate(data, y, x, snap_border=False):
    """Bilinearly sample the 2D array ``data`` at the fractional position ``(y, x)``.

    The four integer neighbours of ``(y, x)`` are weighted by their distance to the
    point; neighbours that fall outside ``data`` contribute nothing.

    Args:
        data: 2D array-like exposing ``.shape`` and ``data[row, col]`` indexing.
        y: row coordinate.
        x: column coordinate.
        snap_border: when True, a coordinate in ``(-1, 0]`` is moved to ``0`` and a
            coordinate in ``[size - 1, size)`` is moved to ``size - 1`` before sampling.

    Returns:
        The interpolated value (``0`` if every neighbour is out of bounds).
    """
    height, width = data.shape

    if snap_border:
        if -1 < y <= 0:
            y = 0
        elif height - 1 <= y < height:
            y = height - 1

        if -1 < x <= 0:
            x = 0
        elif width - 1 <= x < width:
            x = width - 1

    y_low = int(math.floor(y))
    x_low = int(math.floor(x))
    y_high = y_low + 1
    x_high = x_low + 1

    # Weight of the "high" neighbour is the fractional part; the "low" one gets the rest.
    wy_h = y - y_low
    wx_h = x - x_low
    wy_l = 1 - wy_h
    wx_l = 1 - wx_h

    val = 0
    for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):
        for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):
            if 0 <= yp < height and 0 <= xp < width:
                val += wx * wy * data[yp, xp]
    return val
393
394


395
class TestRoIAlign(RoIOpTester):
    """RoIOpTester specialisation for ``ops.RoIAlign`` / ``ops.roi_align``,
    plus checks for the quantized variant."""

    mps_backward_atol = 6e-2

    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs):
        return ops.RoIAlign(
            (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
        )(x, rois)

    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, aligned=False, wrap=False):
        obj = ops.RoIAlign(
            (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
        )
        return RoIOpTesterModuleWrapper(obj) if wrap else obj

    def get_script_fn(self, rois, pool_size):
        scripted = torch.jit.script(ops.roi_align)
        return lambda x: scripted(x, rois, pool_size)

    def expected_fn(
        self,
        in_data,
        rois,
        pool_h,
        pool_w,
        spatial_scale=1,
        sampling_ratio=-1,
        aligned=False,
        device=None,
        dtype=torch.float64,
    ):
        """Reference RoIAlign: average of bilinear samples on a regular grid per bin.

        Each row of ``rois`` is read as ``(batch_idx, x1, y1, x2, y2)``. With
        ``aligned=True`` the scaled coordinates are shifted by -0.5. When
        ``sampling_ratio`` is not positive, the grid size per bin is ``ceil(bin size)``.
        """
        if device is None:
            device = torch.device("cpu")
        n_channels = in_data.size(1)
        out_data = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device)

        offset = 0.5 if aligned else 0.0

        for r, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (coord.item() * spatial_scale - offset for coord in roi[1:])

            roi_h = i_end - i_begin
            roi_w = j_end - j_begin
            bin_h = roi_h / pool_h
            bin_w = roi_w / pool_w

            # Grid sizes depend only on the RoI, so compute them once per RoI.
            grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h))
            grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))

            for i in range(0, pool_h):
                start_h = i_begin + i * bin_h
                for j in range(0, pool_w):
                    start_w = j_begin + j * bin_w

                    for channel in range(0, n_channels):
                        val = 0
                        for iy in range(0, grid_h):
                            y = start_h + (iy + 0.5) * bin_h / grid_h
                            for ix in range(0, grid_w):
                                x = start_w + (ix + 0.5) * bin_w / grid_w
                                val += bilinear_interpolate(in_data[batch_idx, channel, :, :], y, x, snap_border=True)
                        val /= grid_h * grid_w

                        out_data[r, channel, i, j] = val
        return out_data

    def test_boxes_shape(self):
        self._helper_boxes_shape(ops.roi_align)

    @pytest.mark.parametrize("aligned", (True, False))
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
    @pytest.mark.parametrize("x_dtype", (torch.float16, torch.float32, torch.float64), ids=str)
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("deterministic", (True, False))
    def test_forward(self, device, contiguous, deterministic, aligned, x_dtype, rois_dtype=None):
        if deterministic and device == "cpu":
            pytest.skip("cpu is always deterministic, don't retest")
        super().test_forward(
            device=device,
            contiguous=contiguous,
            deterministic=deterministic,
            x_dtype=x_dtype,
            rois_dtype=rois_dtype,
            aligned=aligned,
        )

    @needs_cuda
    @pytest.mark.parametrize("aligned", (True, False))
    @pytest.mark.parametrize("deterministic", (True, False))
    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
    def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
        with torch.cuda.amp.autocast():
            self.test_forward(
                torch.device("cuda"),
                contiguous=False,
                deterministic=deterministic,
                aligned=aligned,
                x_dtype=x_dtype,
                rois_dtype=rois_dtype,
            )

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("deterministic", (True, False))
    def test_backward(self, seed, device, contiguous, deterministic):
        if deterministic and device == "cpu":
            pytest.skip("cpu is always deterministic, don't retest")
        super().test_backward(seed, device, contiguous, deterministic)

    def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
        """Build ``num_rois`` random rows of ``(batch_idx, x1, y1, x2, y2)`` in ``dtype``."""
        rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype)
        rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,))  # set batch index
        rois[:, 3:] += rois[:, 1:3]  # make sure boxes aren't degenerate
        return rois

    @pytest.mark.parametrize("aligned", (True, False))
    @pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 10), (0.1, 50)))
    @pytest.mark.parametrize("qdtype", (torch.qint8, torch.quint8, torch.qint32))
    def test_qroialign(self, aligned, scale, zero_point, qdtype):
        """Make sure quantized version of RoIAlign is close to float version"""
        pool_size = 5
        img_size = 10
        n_channels = 2
        num_imgs = 1
        dtype = torch.float

        x = torch.randint(50, 100, size=(num_imgs, n_channels, img_size, img_size)).to(dtype)
        qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=qdtype)

        rois = self._make_rois(img_size, num_imgs, dtype)
        qrois = torch.quantize_per_tensor(rois, scale=scale, zero_point=zero_point, dtype=qdtype)

        x, rois = qx.dequantize(), qrois.dequantize()  # we want to pass the same inputs

        y = ops.roi_align(
            x,
            rois,
            output_size=pool_size,
            spatial_scale=1,
            sampling_ratio=-1,
            aligned=aligned,
        )
        qy = ops.roi_align(
            qx,
            qrois,
            output_size=pool_size,
            spatial_scale=1,
            sampling_ratio=-1,
            aligned=aligned,
        )

        # The output qy is itself a quantized tensor and there might have been a loss of info when it was
        # quantized. For a fair comparison we need to quantize y as well
        quantized_float_y = torch.quantize_per_tensor(y, scale=scale, zero_point=zero_point, dtype=qdtype)

        # Ideally the two would match exactly, which is the case with (scale, zero) == (1, 0).
        if not (qy == quantized_float_y).all():
            # But because the computations aren't exactly the same between the 2 RoIAlign procedures, some
            # rounding error may lead to a difference of 2 in the output.
            # For example with (scale, zero) = (2, 10), 45.00000... will be quantized to 44
            # but 45.00000001 will be rounded to 46. We make sure below that:
            # - such discrepancies between qy and quantized_float_y are very rare (less than 5%)
            # - any difference between qy and quantized_float_y is == scale
            diff_idx = torch.where(qy != quantized_float_y)
            num_diff = diff_idx[0].numel()
            assert num_diff / qy.numel() < 0.05

            abs_diff = torch.abs(qy[diff_idx].dequantize() - quantized_float_y[diff_idx].dequantize())
            t_scale = torch.full_like(abs_diff, fill_value=scale)
            torch.testing.assert_close(abs_diff, t_scale, rtol=1e-5, atol=1e-5)

    def test_qroi_align_multiple_images(self):
        """The quantized op is expected to reject a batch holding more than one image."""
        dtype = torch.float
        x = torch.randint(50, 100, size=(2, 3, 10, 10)).to(dtype)
        qx = torch.quantize_per_tensor(x, scale=1, zero_point=0, dtype=torch.qint8)
        rois = self._make_rois(img_size=10, num_imgs=2, dtype=dtype, num_rois=10)
        qrois = torch.quantize_per_tensor(rois, scale=1, zero_point=0, dtype=torch.qint8)
        with pytest.raises(RuntimeError, match="Only one image per batch is allowed"):
            ops.roi_align(qx, qrois, output_size=5)

    def test_jit_boxes_list(self):
        model = PoolWrapper(ops.RoIAlign(output_size=[3, 3], spatial_scale=1.0, sampling_ratio=-1))
        self._helper_jit_boxes_list(model)

582

583
class TestPSRoIAlign(RoIOpTester):
    """RoIOpTester specialisation for ``ops.PSRoIAlign`` / ``ops.ps_roi_align``."""

    mps_backward_atol = 5e-2

    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois)

    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, wrap=False):
        obj = ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)
        return RoIOpTesterModuleWrapper(obj) if wrap else obj

    def get_script_fn(self, rois, pool_size):
        scripted = torch.jit.script(ops.ps_roi_align)
        return lambda x: scripted(x, rois, pool_size)

    def expected_fn(
        self, in_data, rois, pool_h, pool_w, device=None, spatial_scale=1, sampling_ratio=-1, dtype=torch.float64
    ):
        """Reference position-sensitive RoIAlign.

        Each row of ``rois`` is read as ``(batch_idx, x1, y1, x2, y2)``; scaled coordinates
        are always shifted by -0.5. Output channel ``c_out`` at bin ``(i, j)`` averages
        bilinear samples of input channel ``c_out * pool_h * pool_w + pool_w * i + j``.
        ``device`` falls back to CPU when omitted or ``None``, as in the sibling testers.
        """
        if device is None:
            device = torch.device("cpu")
        n_input_channels = in_data.size(1)
        assert n_input_channels % (pool_h * pool_w) == 0, "input channels must be divisible by ph * pw"
        # Exact thanks to the divisibility check above.
        n_output_channels = n_input_channels // (pool_h * pool_w)
        out_data = torch.zeros(rois.size(0), n_output_channels, pool_h, pool_w, dtype=dtype, device=device)

        for r, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (coord.item() * spatial_scale - 0.5 for coord in roi[1:])

            roi_h = i_end - i_begin
            roi_w = j_end - j_begin
            bin_h = roi_h / pool_h
            bin_w = roi_w / pool_w

            # Grid sizes depend only on the RoI, so compute them once per RoI.
            grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h))
            grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))

            for i in range(0, pool_h):
                start_h = i_begin + i * bin_h
                for j in range(0, pool_w):
                    start_w = j_begin + j * bin_w
                    for c_out in range(0, n_output_channels):
                        c_in = c_out * (pool_h * pool_w) + pool_w * i + j

                        val = 0
                        for iy in range(0, grid_h):
                            y = start_h + (iy + 0.5) * bin_h / grid_h
                            for ix in range(0, grid_w):
                                x = start_w + (ix + 0.5) * bin_w / grid_w
                                val += bilinear_interpolate(in_data[batch_idx, c_in, :, :], y, x, snap_border=True)
                        val /= grid_h * grid_w

                        out_data[r, c_out, i, j] = val
        return out_data

    def test_boxes_shape(self):
        self._helper_boxes_shape(ops.ps_roi_align)

639

640
class TestMultiScaleRoIAlign:
    """Tests for ``ops.poolers.MultiScaleRoIAlign``: repr and graph node names."""

    def make_obj(self, fmap_names=None, output_size=(7, 7), sampling_ratio=2, wrap=False):
        """Build a MultiScaleRoIAlign, wrapped in a three-input module when ``wrap`` is set."""
        if fmap_names is None:
            fmap_names = ["0"]
        obj = ops.poolers.MultiScaleRoIAlign(fmap_names, output_size, sampling_ratio)
        return MultiScaleRoIAlignModuleWrapper(obj) if wrap else obj

    def test_msroialign_repr(self):
        fmap_names = ["0"]
        output_size = (7, 7)
        sampling_ratio = 2
        # Pass mock feature map names
        t = self.make_obj(fmap_names, output_size, sampling_ratio, wrap=False)

        # Check integrity of object __repr__ attribute
        expected_string = (
            f"MultiScaleRoIAlign(featmap_names={fmap_names}, output_size={output_size}, "
            f"sampling_ratio={sampling_ratio})"
        )
        assert repr(t) == expected_string

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_is_leaf_node(self, device):
        """Check the node-name lists of the wrapped op: one node per input plus one more."""
        op_obj = self.make_obj(wrap=True).to(device=device)
        graph_node_names = get_graph_node_names(op_obj)

        assert len(graph_node_names) == 2
        assert len(graph_node_names[0]) == len(graph_node_names[1])
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs

670

671
672
class TestNMS:
    """Tests for ``ops.nms`` and the two ``batched_nms`` implementations."""

    def _reference_nms(self, boxes, scores, iou_threshold):
        """Greedy NMS written with plain tensor ops, used as ground truth.

        Args:
            boxes: boxes in corner-form
            scores: probabilities
            iou_threshold: intersection over union threshold

        Returns:
            A tensor with the indexes of the kept boxes, highest score first.
        """
        picked = []
        _, indexes = scores.sort(descending=True)
        while len(indexes) > 0:
            current = indexes[0]
            picked.append(current.item())
            if len(indexes) == 1:
                break
            current_box = boxes[current, :]
            indexes = indexes[1:]
            rest_boxes = boxes[indexes, :]
            iou = ops.box_iou(rest_boxes, current_box.unsqueeze(0)).squeeze(1)
            # Drop every remaining candidate that overlaps the picked box too much.
            indexes = indexes[iou <= iou_threshold]

        return torch.as_tensor(picked)

    def _create_tensors_with_iou(self, N, iou_thresh):
        """Create ``N`` random boxes and scores with one engineered overlap.

        Args:
            N: number of boxes (and scores) to generate.
            iou_thresh: IoU threshold that the last box should barely exceed
                with respect to the first box.

        Returns:
            A ``(boxes, scores)`` pair of tensors.
        """
        # force last box to have a pre-defined iou with the first box
        # let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1],
        # then, in order to satisfy ops.iou(b0, b1) == iou_thresh,
        # we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh
        # Adjust the threshold upward a bit with the intent of creating
        # at least one box that exceeds (barely) the threshold and so
        # should be suppressed.
        boxes = torch.rand(N, 4) * 100
        # Turn the (x, y, w, h)-like random draw into corner form with x1 >= x0, y1 >= y0.
        boxes[:, 2:] += boxes[:, :2]
        boxes[-1, :] = boxes[0, :]
        x0, y0, x1, y1 = boxes[-1].tolist()
        iou_thresh += 1e-5
        boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh
        scores = torch.rand(N)
        return boxes, scores

    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.parametrize("seed", range(10))
    def test_nms_ref(self, iou, seed):
        """``ops.nms`` must keep exactly the boxes kept by the reference implementation."""
        torch.random.manual_seed(seed)
        err_msg = "NMS incompatible between CPU and reference implementation for IoU={}"
        boxes, scores = self._create_tensors_with_iou(1000, iou)
        keep_ref = self._reference_nms(boxes, scores, iou)
        keep = ops.nms(boxes, scores, iou)
        torch.testing.assert_close(keep, keep_ref, msg=err_msg.format(iou))

    def test_nms_input_errors(self):
        """Malformed box/score shapes must raise ``RuntimeError``."""
        with pytest.raises(RuntimeError):
            ops.nms(torch.rand(4), torch.rand(3), 0.5)
        with pytest.raises(RuntimeError):
            ops.nms(torch.rand(3, 5), torch.rand(3), 0.5)
        with pytest.raises(RuntimeError):
            ops.nms(torch.rand(3, 4), torch.rand(3, 2), 0.5)
        with pytest.raises(RuntimeError):
            ops.nms(torch.rand(3, 4), torch.rand(4), 0.5)

    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 50), (3, 10)))
    def test_qnms(self, iou, scale, zero_point):
        """NMS on quantized inputs must match NMS on their dequantized values."""
        # Note: we compare qnms vs nms instead of qnms vs reference implementation.
        # This is because with the int conversion, the trick used in _create_tensors_with_iou
        # doesn't really work (in fact, nms vs reference implem will also fail with ints)
        err_msg = "NMS and QNMS give different results for IoU={}"
        boxes, scores = self._create_tensors_with_iou(1000, iou)
        scores *= 100  # otherwise most scores would be 0 or 1 after int conversion

        qboxes = torch.quantize_per_tensor(boxes, scale=scale, zero_point=zero_point, dtype=torch.quint8)
        qscores = torch.quantize_per_tensor(scores, scale=scale, zero_point=zero_point, dtype=torch.quint8)

        # Compare against the values the quantized tensors actually represent.
        boxes = qboxes.dequantize()
        scores = qscores.dequantize()

        keep = ops.nms(boxes, scores, iou)
        qkeep = ops.nms(qboxes, qscores, iou)

        torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou))

    @pytest.mark.parametrize(
        "device",
        (
            pytest.param("cuda", marks=pytest.mark.needs_cuda),
            pytest.param("mps", marks=pytest.mark.needs_mps),
        ),
    )
    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    def test_nms_gpu(self, iou, device, dtype=torch.float64):
        """NMS on an accelerator must agree with NMS on CPU.

        ``dtype`` only selects the tolerance used for the duplicate-score
        fallback below; the generated boxes and scores are not cast to it.
        """
        dtype = torch.float32 if device == "mps" else dtype
        tol = 1e-3 if dtype is torch.half else 1e-5
        # Name the device actually under test (this also runs on MPS, not only CUDA).
        err_msg = f"NMS incompatible between CPU and {device} for IoU={iou}"

        boxes, scores = self._create_tensors_with_iou(1000, iou)
        r_cpu = ops.nms(boxes, scores, iou)
        r_gpu = ops.nms(boxes.to(device), scores.to(device), iou)

        is_eq = torch.allclose(r_cpu, r_gpu.cpu())
        if not is_eq:
            # if the indices are not the same, ensure that it's because the scores
            # are duplicate
            is_eq = torch.allclose(scores[r_cpu], scores[r_gpu.cpu()], rtol=tol, atol=tol)
        assert is_eq, err_msg

    @needs_cuda
    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.parametrize("dtype", (torch.float, torch.half))
    def test_autocast(self, iou, dtype):
        """Re-run the CPU-vs-CUDA comparison inside a CUDA autocast context."""
        with torch.cuda.amp.autocast():
            self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")

    @pytest.mark.parametrize(
        "device",
        (
            pytest.param("cuda", marks=pytest.mark.needs_cuda),
            pytest.param("mps", marks=pytest.mark.needs_mps),
        ),
    )
    def test_nms_float16(self, device):
        """float16 inputs must keep the same boxes as the default-dtype inputs."""
        boxes = torch.tensor(
            [
                [285.3538, 185.5758, 1193.5110, 851.4551],
                [285.1472, 188.7374, 1192.4984, 851.0669],
                [279.2440, 197.9812, 1189.4746, 849.2019],
            ]
        ).to(device)
        scores = torch.tensor([0.6370, 0.7569, 0.3966]).to(device)

        iou_thres = 0.2
        keep32 = ops.nms(boxes, scores, iou_thres)
        keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
        assert_equal(keep32, keep16)

    @pytest.mark.parametrize("seed", range(10))
    def test_batched_nms_implementations(self, seed):
        """Make sure that both implementations of batched_nms yield identical results"""
        torch.random.manual_seed(seed)

        num_boxes = 1000
        iou_threshold = 0.9

        # Top-left corners lie in [0, 1) and bottom-right corners in [10, 11),
        # so every box is well-formed; the asserts below double-check that.
        boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1)
        assert boxes[:, 0].max() < boxes[:, 2].min()  # x1 < x2
        assert boxes[:, 1].max() < boxes[:, 3].min()  # y1 < y2

        scores = torch.rand(num_boxes)
        idxs = torch.randint(0, 4, size=(num_boxes,))
        keep_vanilla = ops.boxes._batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
        keep_trick = ops.boxes._batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)

        torch.testing.assert_close(
            keep_vanilla, keep_trick, msg="The vanilla and the trick implementation yield different nms outputs."
        )

        # Also make sure an empty tensor is returned if boxes is empty
        empty = torch.empty((0,), dtype=torch.int64)
        torch.testing.assert_close(empty, ops.batched_nms(empty, None, None, None))
831

832

833
834
835
class TestDeformConv:
    """Tests for ``ops.deform_conv2d`` and the ``ops.DeformConv2d`` module."""

    # Default dtype for the generated inputs; double precision is what the
    # gradcheck-based tests below run with.
    dtype = torch.float64

    def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
        """Reference deformable convolution written as explicit Python loops.

        Args:
            x: input tensor indexed as ``[batch, channel, row, col]``.
            weight: kernel indexed as ``[out_channel, in_channel_in_group, row, col]``.
            offset: sampling offsets indexed as ``[batch, offset_channel, out_row, out_col]``;
                channel ``2 * k`` holds the row offset and ``2 * k + 1`` the column offset.
            mask: optional modulation mask indexed like ``offset``; ``None`` means
                every sample is weighted by 1.
            bias: optional per-output-channel bias; skipped when ``None``.
            stride, padding, dilation: ints or pairs, expanded with ``_pair``.

        Returns:
            The output tensor, created on ``x``'s device with ``x``'s dtype.
        """
        stride_h, stride_w = _pair(stride)
        pad_h, pad_w = _pair(padding)
        dil_h, dil_w = _pair(dilation)
        weight_h, weight_w = weight.shape[-2:]

        n_batches, n_in_channels, in_h, in_w = x.shape
        n_out_channels = weight.shape[0]

        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1

        # Group counts are inferred from the tensor shapes rather than passed in.
        n_offset_grps = offset.shape[1] // (2 * weight_h * weight_w)
        in_c_per_offset_grp = n_in_channels // n_offset_grps

        n_weight_grps = n_in_channels // weight.shape[1]
        in_c_per_weight_grp = weight.shape[1]
        out_c_per_weight_grp = n_out_channels // n_weight_grps

        out = torch.zeros(n_batches, n_out_channels, out_h, out_w, device=x.device, dtype=x.dtype)
        for b in range(n_batches):
            for c_out in range(n_out_channels):
                for i in range(out_h):
                    for j in range(out_w):
                        for di in range(weight_h):
                            for dj in range(weight_w):
                                for c in range(in_c_per_weight_grp):
                                    weight_grp = c_out // out_c_per_weight_grp
                                    c_in = weight_grp * in_c_per_weight_grp + c

                                    offset_grp = c_in // in_c_per_offset_grp
                                    mask_idx = offset_grp * (weight_h * weight_w) + di * weight_w + dj
                                    offset_idx = 2 * mask_idx

                                    # Regular-grid sampling position plus the learned offset.
                                    pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j]
                                    pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j]

                                    mask_value = 1.0
                                    if mask is not None:
                                        mask_value = mask[b, mask_idx, i, j]

                                    out[b, c_out, i, j] += (
                                        mask_value
                                        * weight[c_out, c, di, dj]
                                        * bilinear_interpolate(x[b, c_in, :, :], pi, pj)
                                    )
        if bias is not None:
            out += bias.view(1, n_out_channels, 1, 1)
        return out

    # NOTE(review): lru_cache on an instance method keys on ``self`` and keeps
    # every instance (and the tensors returned here) alive for the lifetime of
    # the cache. Kept as-is to preserve behaviour; consider removing it.
    @lru_cache(maxsize=None)
    def get_fn_args(self, device, contiguous, batch_sz, dtype):
        """Generate random inputs for a fixed deformable-convolution configuration.

        Args:
            device: device on which all tensors are created.
            contiguous: when false, the tensors are returned with the same shapes
                but a permuted (non-contiguous) memory layout.
            batch_sz: batch size of ``x``, ``offset`` and ``mask``.
            dtype: dtype of all tensors.

        Returns:
            ``(x, weight, offset, mask, bias, stride, pad, dilation)``.
        """
        n_in_channels = 6
        n_out_channels = 2
        n_weight_grps = 2
        n_offset_grps = 3

        stride = (2, 1)
        pad = (1, 0)
        dilation = (2, 1)

        stride_h, stride_w = stride
        pad_h, pad_w = pad
        dil_h, dil_w = dilation
        weight_h, weight_w = (3, 2)
        in_h, in_w = (5, 4)

        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1

        x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=dtype, requires_grad=True)

        offset = torch.randn(
            batch_sz,
            n_offset_grps * 2 * weight_h * weight_w,
            out_h,
            out_w,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )

        mask = torch.randn(
            batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w, device=device, dtype=dtype, requires_grad=True
        )

        weight = torch.randn(
            n_out_channels,
            n_in_channels // n_weight_grps,
            weight_h,
            weight_w,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )

        bias = torch.randn(n_out_channels, device=device, dtype=dtype, requires_grad=True)

        if not contiguous:
            # Permute, materialise, then permute back: same shape and values,
            # different strides.
            x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
            offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)

        return x, weight, offset, mask, bias, stride, pad, dilation

    def make_obj(self, in_channels=6, out_channels=2, kernel_size=(3, 2), groups=2, wrap=False):
        """Build a ``DeformConv2d`` with the stride/padding/dilation used by ``get_fn_args``.

        Returns the bare layer, or a ``DeformConvModuleWrapper`` around it when
        ``wrap`` is true.
        """
        obj = ops.DeformConv2d(
            in_channels, out_channels, kernel_size, stride=(2, 1), padding=(1, 0), dilation=(2, 1), groups=groups
        )
        return DeformConvModuleWrapper(obj) if wrap else obj

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_is_leaf_node(self, device):
        """The wrapped layer must show up as a single node in the traced graph."""
        op_obj = self.make_obj(wrap=True).to(device=device)
        graph_node_names = get_graph_node_names(op_obj)

        assert len(graph_node_names) == 2
        assert len(graph_node_names[0]) == len(graph_node_names[1])
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("batch_sz", (0, 33))
    def test_forward(self, device, contiguous, batch_sz, dtype=None):
        """The layer's forward pass must match ``expected_fn``, with and without a mask."""
        dtype = dtype or self.dtype
        x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
        in_channels = 6
        out_channels = 2
        kernel_size = (3, 2)
        groups = 2
        tol = 2e-3 if dtype is torch.half else 1e-5

        layer = self.make_obj(in_channels, out_channels, kernel_size, groups, wrap=False).to(
            device=x.device, dtype=dtype
        )
        res = layer(x, offset, mask)

        # Feed the layer's own parameters to the reference implementation.
        weight = layer.weight.data
        bias = layer.bias.data
        expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation)

        torch.testing.assert_close(
            res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}"
        )

        # no modulation test
        res = layer(x, offset)
        expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation)

        torch.testing.assert_close(
            res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}"
        )

    def test_wrong_sizes(self):
        """Offsets or masks with a wrong channel count must raise ``RuntimeError``."""
        in_channels = 6
        out_channels = 2
        kernel_size = (3, 2)
        groups = 2
        x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(
            "cpu", contiguous=True, batch_sz=10, dtype=self.dtype
        )
        layer = ops.DeformConv2d(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups
        )
        with pytest.raises(RuntimeError, match="the shape of the offset"):
            wrong_offset = torch.rand_like(offset[:, :2])
            layer(x, wrong_offset)

        with pytest.raises(RuntimeError, match=r"mask.shape\[1\] is not valid"):
            wrong_mask = torch.rand_like(mask[:, :2])
            layer(x, offset, wrong_mask)

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("batch_sz", (0, 33))
    def test_backward(self, device, contiguous, batch_sz):
        """Gradient-check ``ops.deform_conv2d`` in eager and scripted form, with and without a mask."""
        x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(
            device, contiguous, batch_sz, self.dtype
        )

        def func(x_, offset_, mask_, weight_, bias_):
            return ops.deform_conv2d(
                x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=mask_
            )

        gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5, fast_mode=True)

        def func_no_mask(x_, offset_, weight_, bias_):
            return ops.deform_conv2d(
                x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=None
            )

        gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5, fast_mode=True)

        @torch.jit.script
        def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_):
            # type:(Tensor, Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
            return ops.deform_conv2d(
                x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_, mask=mask_
            )

        gradcheck(
            lambda z, off, msk, wei, bi: script_func(z, off, msk, wei, bi, stride, padding, dilation),
            (x, offset, mask, weight, bias),
            nondet_tol=1e-5,
            fast_mode=True,
        )

        @torch.jit.script
        def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
            # type:(Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
            return ops.deform_conv2d(
                x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_, mask=None
            )

        gradcheck(
            lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation),
            (x, offset, weight, bias),
            nondet_tol=1e-5,
            fast_mode=True,
        )

    @needs_cuda
    @pytest.mark.parametrize("contiguous", (True, False))
    def test_compare_cpu_cuda_grads(self, contiguous):
        """Weight gradients computed on CUDA must match those computed on CPU."""
        # Test from https://github.com/pytorch/vision/issues/2598
        # Run on CUDA only

        # compare grads computed on CUDA with grads computed on CPU
        true_cpu_grads = None

        init_weight = torch.randn(9, 9, 3, 3, requires_grad=True)
        img = torch.randn(8, 9, 1000, 110)
        offset = torch.rand(8, 2 * 3 * 3, 1000, 110)
        mask = torch.rand(8, 3 * 3, 1000, 110)

        if not contiguous:
            img = img.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
            offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            weight = init_weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
        else:
            weight = init_weight

        for d in ["cpu", "cuda"]:
            # Both passes write into init_weight.grad. Reset it so the second
            # backward does not accumulate onto the first device's gradient.
            init_weight.grad = None
            out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d))
            out.mean().backward()
            assert init_weight.grad is not None
            # Snapshot the gradient: keeping a reference to init_weight.grad
            # itself would make the comparison below compare a tensor with itself.
            grads = init_weight.grad.detach().clone().to("cpu")
            if true_cpu_grads is None:
                true_cpu_grads = grads
            else:
                torch.testing.assert_close(true_cpu_grads, grads)

    @needs_cuda
    @pytest.mark.parametrize("batch_sz", (0, 33))
    @pytest.mark.parametrize("dtype", (torch.float, torch.half))
    def test_autocast(self, batch_sz, dtype):
        """Re-run the forward check on CUDA inside an autocast context."""
        with torch.cuda.amp.autocast():
            self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype)

    def test_forward_scriptability(self):
        """``DeformConv2d`` must be scriptable with ``torch.jit.script``."""
        # Non-regression test for https://github.com/pytorch/vision/issues/4078
        torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3))

1102
1103

class TestFrozenBNT:
    """Tests for ``ops.misc.FrozenBatchNorm2d``."""

    def test_frozenbatchnorm2d_repr(self):
        """``repr`` must show the number of features and the epsilon."""
        num_features = 32
        eps = 1e-5
        t = ops.misc.FrozenBatchNorm2d(num_features, eps=eps)

        # Check integrity of object __repr__ attribute
        expected_string = f"FrozenBatchNorm2d({num_features}, eps={eps})"
        assert repr(t) == expected_string

    @pytest.mark.parametrize("seed", range(10))
    def test_frozenbatchnorm2d_eps(self, seed):
        """The frozen layer must match an eval-mode ``BatchNorm2d`` loaded with the same state."""
        torch.random.manual_seed(seed)
        sample_size = (4, 32, 28, 28)
        x = torch.rand(sample_size)
        state_dict = dict(
            weight=torch.rand(sample_size[1]),
            bias=torch.rand(sample_size[1]),
            running_mean=torch.rand(sample_size[1]),
            running_var=torch.rand(sample_size[1]),
            num_batches_tracked=torch.tensor(100),
        )

        # Check that default eps is equal to the one of BN
        fbn = ops.misc.FrozenBatchNorm2d(sample_size[1])
        # strict=False: the state dict is shared with BatchNorm2d and is loaded
        # non-strictly into the frozen layer.
        fbn.load_state_dict(state_dict, strict=False)
        bn = torch.nn.BatchNorm2d(sample_size[1]).eval()
        bn.load_state_dict(state_dict)
        # Difference is expected to fall in an acceptable range
        torch.testing.assert_close(fbn(x), bn(x), rtol=1e-5, atol=1e-6)

        # Check computation for eps > 0
        fbn = ops.misc.FrozenBatchNorm2d(sample_size[1], eps=1e-5)
        fbn.load_state_dict(state_dict, strict=False)
        bn = torch.nn.BatchNorm2d(sample_size[1], eps=1e-5).eval()
        bn.load_state_dict(state_dict)
        torch.testing.assert_close(fbn(x), bn(x), rtol=1e-5, atol=1e-6)
1140

1141

Aditya Oke's avatar
Aditya Oke committed
1142
class TestBoxConversionToRoi:
    """Tests for the RoI box helpers in ``ops._utils``."""

    # Deliberately a plain function rather than a staticmethod: it is called
    # directly in the class body by the parametrize decorators below, and
    # staticmethod objects are only callable that way from Python 3.10 on.
    def _get_box_sequences():
        """Return the same two RoIs as a tensor, a list of tensors and a tuple of tensors.

        The tensor form has five columns, the first being the batch index; the
        list/tuple forms hold one four-column tensor per batch element.
        """
        # Define here the argument type of `boxes` supported by region pooling operations
        box_tensor = torch.tensor([[0, 0, 0, 100, 100], [1, 0, 0, 100, 100]], dtype=torch.float)
        box_list = [
            torch.tensor([[0, 0, 100, 100]], dtype=torch.float),
            torch.tensor([[0, 0, 100, 100]], dtype=torch.float),
        ]
        box_tuple = tuple(box_list)
        return box_tensor, box_list, box_tuple

    @pytest.mark.parametrize("box_sequence", _get_box_sequences())
    def test_check_roi_boxes_shape(self, box_sequence):
        """Every supported container type must pass the shape check."""
        # Ensure common sequences of tensors are supported
        ops._utils.check_roi_boxes_shape(box_sequence)

    @pytest.mark.parametrize("box_sequence", _get_box_sequences())
    def test_convert_boxes_to_roi_format(self, box_sequence):
        """List and tuple inputs must convert to the reference five-column tensor."""
        # Ensure common sequences of tensors yield the same result
        # The tensor form is the reference RoI layout that the other forms
        # are expected to convert to.
        ref_tensor = TestBoxConversionToRoi._get_box_sequences()[0]
        if isinstance(box_sequence, torch.Tensor):
            # Already in RoI format: there is nothing to convert.
            converted = box_sequence
        else:
            converted = ops._utils.convert_boxes_to_roi_format(box_sequence)
        assert_equal(ref_tensor, converted)
1166
1167


Aditya Oke's avatar
Aditya Oke committed
1168
class TestBoxConvert:
    """Tests for ``ops.box_convert`` between the xyxy, xywh and cxcywh formats."""

    def test_bbox_same(self):
        """Converting a format to itself must return the boxes unchanged."""
        box_tensor = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
        )

        exp_xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)

        assert exp_xyxy.size() == torch.Size([4, 4])
        assert_equal(ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xyxy"), exp_xyxy)
        assert_equal(ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="xywh"), exp_xyxy)
        assert_equal(ops.box_convert(box_tensor, in_fmt="cxcywh", out_fmt="cxcywh"), exp_xyxy)

    def test_bbox_xyxy_xywh(self):
        """Round-trip xyxy -> xywh -> xyxy."""
        # Simple test convert boxes to xywh and back. Make sure they are same.
        # box_tensor is in x1 y1 x2 y2 format.
        box_tensor = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
        )
        exp_xywh = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float)

        assert exp_xywh.size() == torch.Size([4, 4])
        box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
        assert_equal(box_xywh, exp_xywh)

        # Reverse conversion
        box_xyxy = ops.box_convert(box_xywh, in_fmt="xywh", out_fmt="xyxy")
        assert_equal(box_xyxy, box_tensor)

    def test_bbox_xyxy_cxcywh(self):
        """Round-trip xyxy -> cxcywh -> xyxy."""
        # Simple test convert boxes to cxcywh and back. Make sure they are same.
        # box_tensor is in x1 y1 x2 y2 format.
        box_tensor = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
        )
        exp_cxcywh = torch.tensor(
            [[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float
        )

        assert exp_cxcywh.size() == torch.Size([4, 4])
        box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
        assert_equal(box_cxcywh, exp_cxcywh)

        # Reverse conversion
        box_xyxy = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xyxy")
        assert_equal(box_xyxy, box_tensor)

    def test_bbox_xywh_cxcywh(self):
        """Round-trip xywh -> cxcywh -> xywh."""
        box_tensor = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float
        )

        exp_cxcywh = torch.tensor(
            [[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float
        )

        assert exp_cxcywh.size() == torch.Size([4, 4])
        box_cxcywh = ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="cxcywh")
        assert_equal(box_cxcywh, exp_cxcywh)

        # Reverse conversion
        box_xywh = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xywh")
        assert_equal(box_xywh, box_tensor)

    @pytest.mark.parametrize("inv_infmt", ["xwyh", "cxwyh"])
    @pytest.mark.parametrize("inv_outfmt", ["xwcx", "xhwcy"])
    def test_bbox_invalid(self, inv_infmt, inv_outfmt):
        """Unknown format names must raise ``ValueError``."""
        box_tensor = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float
        )

        with pytest.raises(ValueError):
            ops.box_convert(box_tensor, inv_infmt, inv_outfmt)

    def test_bbox_convert_jit(self):
        """The scripted ``box_convert`` must agree with the eager one."""
        box_tensor = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
        )

        scripted_fn = torch.jit.script(ops.box_convert)

        box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
        scripted_xywh = scripted_fn(box_tensor, "xyxy", "xywh")
        torch.testing.assert_close(scripted_xywh, box_xywh)

        box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
        scripted_cxcywh = scripted_fn(box_tensor, "xyxy", "cxcywh")
        torch.testing.assert_close(scripted_cxcywh, box_cxcywh)
1256
1257


Aditya Oke's avatar
Aditya Oke committed
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
class TestBoxArea:
    """Tests for ``ops.box_area`` across integer and floating-point dtypes."""

    def area_check(self, box, expected, atol=1e-4):
        """Assert that ``ops.box_area(box)`` equals ``expected`` within ``atol``.

        The comparison ignores dtype and uses no relative tolerance.
        """
        out = ops.box_area(box)
        torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=atol)

    @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64])
    def test_int_boxes(self, dtype):
        """Integer boxes of every width must give the exact area."""
        box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype)
        expected = torch.tensor([10000, 0], dtype=torch.int32)
        self.area_check(box_tensor, expected)

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
    def test_float_boxes(self, dtype):
        """Single- and double-precision boxes must match the precomputed areas."""
        box_tensor = torch.tensor(FLOAT_BOXES, dtype=dtype)
        expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=dtype)
        self.area_check(box_tensor, expected)

    def test_float16_box(self):
        """Half-precision boxes are checked with a looser absolute tolerance."""
        box_tensor = torch.tensor(
            [[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]], dtype=torch.float16
        )

        expected = torch.tensor([3.2170, 3.7108, 18.5071], dtype=torch.float16)
        self.area_check(box_tensor, expected, atol=0.01)

    def test_box_area_jit(self):
        """The scripted ``box_area`` must agree with the eager one."""
        box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)
        expected = ops.box_area(box_tensor)
        scripted_fn = torch.jit.script(ops.box_area)
        scripted_area = scripted_fn(box_tensor)
        torch.testing.assert_close(scripted_area, expected)
1289

Aditya Oke's avatar
Aditya Oke committed
1290

1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
# Shared box fixtures for the box-area and IoU tests in this module.
# Four integer-valued boxes.
INT_BOXES = [
    [0, 0, 100, 100],
    [0, 0, 50, 50],
    [200, 200, 300, 300],
    [0, 0, 25, 25],
]
# The first three boxes of INT_BOXES, spelled out as an independent list.
INT_BOXES2 = [
    [0, 0, 100, 100],
    [0, 0, 50, 50],
    [200, 200, 300, 300],
]
# Three large, heavily overlapping boxes with fractional coordinates.
FLOAT_BOXES = [
    [285.3538, 185.5758, 1193.5110, 851.4551],
    [285.1472, 188.7374, 1192.4984, 851.0669],
    [279.2440, 197.9812, 1189.4746, 849.2019],
]


def gen_box(size, dtype=torch.float):
    """Generate ``size`` random boxes as a ``(size, 4)`` tensor.

    The first two columns are drawn uniformly from ``[0, 1)``; the last two
    are the first two plus another uniform draw, so they are never smaller.

    Args:
        size: number of boxes to generate.
        dtype: dtype of the random draws and of the result.

    Returns:
        A tensor whose rows are ``[first_corner, second_corner]``.
    """
    first_corner = torch.rand((size, 2), dtype=dtype)
    extent = torch.rand((size, 2), dtype=dtype)
    second_corner = first_corner + extent
    return torch.cat((first_corner, second_corner), dim=-1)


Aditya Oke's avatar
Aditya Oke committed
1306
1307
class TestIouBase:
    """Shared helpers for the test classes of the IoU-style box operators."""

    @staticmethod
    def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected):
        """Check ``target_fn`` against ``expected`` for every dtype in ``dtypes``.

        Args:
            target_fn: operator under test, called as ``target_fn(boxes1, boxes2)``.
            actual_box1: first set of boxes, as nested Python lists.
            actual_box2: second set of boxes, as nested Python lists.
            dtypes: dtypes the inputs are materialised in, one run per dtype.
            atol: absolute tolerance of the comparison (no relative tolerance).
            expected: expected result, as nested Python lists.
        """
        expected_box = torch.tensor(expected)
        for dtype in dtypes:
            # Build fresh tensors from the original lists on every iteration.
            # Rebinding the arguments would derive each dtype from the
            # previous dtype's tensor instead of from the raw values.
            box1 = torch.tensor(actual_box1, dtype=dtype)
            box2 = torch.tensor(actual_box2, dtype=dtype)
            out = target_fn(box1, box2)
            torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)

    @staticmethod
    def _run_jit_test(target_fn: Callable, actual_box: List):
        """Check that the scripted ``target_fn`` agrees with the eager one."""
        box_tensor = torch.tensor(actual_box, dtype=torch.float)
        expected = target_fn(box_tensor, box_tensor)
        scripted_fn = torch.jit.script(target_fn)
        scripted_out = scripted_fn(box_tensor, box_tensor)
        torch.testing.assert_close(scripted_out, expected)

    @staticmethod
    def _cartesian_product(boxes1, boxes2, target_fn: Callable):
        """Evaluate ``target_fn`` one box pair at a time.

        Returns:
            An ``(N, M)`` tensor whose entry ``[i, j]`` is ``target_fn`` applied
            to ``boxes1[i]`` and ``boxes2[j]``.
        """
        N = boxes1.size(0)
        M = boxes2.size(0)
        result = torch.zeros((N, M))
        for i in range(N):
            for j in range(M):
                result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0))
        return result

    @staticmethod
    def _run_cartesian_test(target_fn: Callable):
        """Check that the batched result equals the pair-by-pair result."""
        boxes1 = gen_box(5)
        boxes2 = gen_box(7)
        a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn)
        b = target_fn(boxes1, boxes2)
        torch.testing.assert_close(a, b)
1341

1342

Aditya Oke's avatar
Aditya Oke committed
1343
class TestBoxIou(TestIouBase):
    """Value, TorchScript and pairwise-consistency checks for ``ops.box_iou``."""

    # Expected output for the (INT_BOXES, INT_BOXES2) case.
    int_expected = [
        [1.0, 0.25, 0.0],
        [0.25, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.0625, 0.25, 0.0],
    ]
    # Expected output for the (FLOAT_BOXES, FLOAT_BOXES) cases.
    float_expected = [
        [1.0, 0.9933, 0.9673],
        [0.9933, 1.0, 0.9737],
        [0.9673, 0.9737, 1.0],
    ]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            # integer inputs
            (INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            # half precision uses a looser tolerance
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            # single and double precision
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        iou_op = ops.box_iou
        self._run_test(iou_op, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        iou_op = ops.box_iou
        self._run_jit_test(iou_op, INT_BOXES)

    def test_iou_cartesian(self):
        iou_op = ops.box_iou
        self._run_cartesian_test(iou_op)

1364

Aditya Oke's avatar
Aditya Oke committed
1365
class TestGeneralizedBoxIou(TestIouBase):
    """Value, TorchScript and pairwise-consistency checks for ``ops.generalized_box_iou``."""

    # Expected output for the (INT_BOXES, INT_BOXES2) case.
    int_expected = [
        [1.0, 0.25, -0.7778],
        [0.25, 1.0, -0.8611],
        [-0.7778, -0.8611, 1.0],
        [0.0625, 0.25, -0.8819],
    ]
    # Expected output for the (FLOAT_BOXES, FLOAT_BOXES) cases.
    float_expected = [
        [1.0, 0.9933, 0.9673],
        [0.9933, 1.0, 0.9737],
        [0.9673, 0.9737, 1.0],
    ]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            # integer inputs
            (INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            # half precision uses a looser tolerance
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            # single and double precision
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        iou_op = ops.generalized_box_iou
        self._run_test(iou_op, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        iou_op = ops.generalized_box_iou
        self._run_jit_test(iou_op, INT_BOXES)

    def test_iou_cartesian(self):
        iou_op = ops.generalized_box_iou
        self._run_cartesian_test(iou_op)

1386

Aditya Oke's avatar
Aditya Oke committed
1387
class TestDistanceBoxIoU(TestIouBase):
    """Value, TorchScript and pairwise-consistency checks for ``ops.distance_box_iou``."""

    # Expected output for the (INT_BOXES, INT_BOXES2) case.
    int_expected = [
        [1.0, 0.1875, -0.4444],
        [0.1875, 1.0, -0.5625],
        [-0.4444, -0.5625, 1.0],
        [-0.0781, 0.1875, -0.6267],
    ]
    # Expected output for the (FLOAT_BOXES, FLOAT_BOXES) cases.
    float_expected = [
        [1.0, 0.9933, 0.9673],
        [0.9933, 1.0, 0.9737],
        [0.9673, 0.9737, 1.0],
    ]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            # integer inputs
            (INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            # half precision uses a looser tolerance
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            # single and double precision
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        iou_op = ops.distance_box_iou
        self._run_test(iou_op, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        iou_op = ops.distance_box_iou
        self._run_jit_test(iou_op, INT_BOXES)

    def test_iou_cartesian(self):
        iou_op = ops.distance_box_iou
        self._run_cartesian_test(iou_op)

1413

Aditya Oke's avatar
Aditya Oke committed
1414
class TestCompleteBoxIou(TestIouBase):
    """Value, TorchScript and pairwise-consistency checks for ``ops.complete_box_iou``."""

    # Expected output for the (INT_BOXES, INT_BOXES2) case.
    int_expected = [
        [1.0, 0.1875, -0.4444],
        [0.1875, 1.0, -0.5625],
        [-0.4444, -0.5625, 1.0],
        [-0.0781, 0.1875, -0.6267],
    ]
    # Expected output for the (FLOAT_BOXES, FLOAT_BOXES) cases.
    float_expected = [
        [1.0, 0.9933, 0.9673],
        [0.9933, 1.0, 0.9737],
        [0.9673, 0.9737, 1.0],
    ]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            # integer inputs
            (INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            # half precision uses a looser tolerance
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            # single and double precision
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        iou_op = ops.complete_box_iou
        self._run_test(iou_op, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        iou_op = ops.complete_box_iou
        self._run_jit_test(iou_op, INT_BOXES)

    def test_iou_cartesian(self):
        iou_op = ops.complete_box_iou
        self._run_cartesian_test(iou_op)

1440

Aditya Oke's avatar
Aditya Oke committed
1441
1442
1443
1444
1445
def get_boxes(dtype, device):
    """Build the fixed boxes used by the IoU-loss tests.

    Returns:
        ``(box1, box2, box3, box4, box1s, box2s)``: four single boxes of shape (4,) followed by
        two batches of shape (2, 4), where ``box1s`` stacks ``box2`` twice and ``box2s`` stacks
        ``box3`` and ``box4``. Every tensor uses the requested ``dtype`` and ``device``.
    """
    coordinates = (
        [-1, -1, 1, 1],
        [0, 0, 1, 1],
        [0, 1, 1, 2],
        [1, 1, 2, 2],
    )
    box1, box2, box3, box4 = (torch.tensor(coords, dtype=dtype, device=device) for coords in coordinates)

    # Batched variants: the same first box is paired with two different second boxes.
    box1s = torch.stack((box2, box2), dim=0)
    box2s = torch.stack((box3, box4), dim=0)

    return box1, box2, box3, box4, box1s, box2s
1451

Aditya Oke's avatar
Aditya Oke committed
1452

Aditya Oke's avatar
Aditya Oke committed
1453
1454
1455
1456
def assert_iou_loss(iou_fn, box1, box2, expected_loss, device, reduction="none"):
    """Assert that ``iou_fn(box1, box2, reduction=reduction)`` is close to ``expected_loss``.

    ``expected_loss`` is turned into a tensor on ``device`` before the comparison, which
    uses the default tolerances of ``torch.testing.assert_close``.
    """
    actual = iou_fn(box1, box2, reduction=reduction)
    target = torch.tensor(expected_loss, device=device)
    torch.testing.assert_close(actual=actual, expected=target)
1457
1458


Aditya Oke's avatar
Aditya Oke committed
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
def assert_empty_loss(iou_fn, dtype, device):
    """Check how ``iou_fn`` behaves on two empty ``(0, 4)`` box tensors.

    With ``reduction="mean"`` the loss must be 0 and backward must populate the gradient of
    both inputs; with ``reduction="none"`` the loss must have no elements.
    """
    empty_shape = [0, 4]
    box1 = torch.randn(empty_shape, dtype=dtype, device=device).requires_grad_()
    box2 = torch.randn(empty_shape, dtype=dtype, device=device).requires_grad_()

    # Reduced loss: must be differentiable and equal to zero.
    mean_loss = iou_fn(box1, box2, reduction="mean")
    mean_loss.backward()
    zero = torch.tensor(0.0, device=device)
    torch.testing.assert_close(mean_loss, zero)
    assert box1.grad is not None, "box1.grad should not be None after backward is called"
    assert box2.grad is not None, "box2.grad should not be None after backward is called"

    # Unreduced loss: one entry per box pair, so nothing for empty inputs.
    unreduced_loss = iou_fn(box1, box2, reduction="none")
    assert unreduced_loss.numel() == 0, f"{str(iou_fn)} for two empty box should be empty"
Aditya Oke's avatar
Aditya Oke committed
1469

Aditya Oke's avatar
Aditya Oke committed
1470

Aditya Oke's avatar
Aditya Oke committed
1471
1472
class TestGeneralizedBoxIouLoss:
    """Tests for ``ops.generalized_box_iou_loss``."""

    # We refer to original test: https://github.com/facebookresearch/fvcore/blob/main/tests/test_giou_loss.py
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_giou_loss(self, dtype, device):
        loss_fn = ops.generalized_box_iou_loss
        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

        unreduced_cases = [
            # Identical boxes should have loss of 0
            (box1, box1, 0.0),
            # quarter size box inside other box = IoU of 0.25
            (box1, box2, 0.75),
            # Two side by side boxes, enclosing area = union: IoU=0 and GIoU=0 (loss 1.0)
            (box2, box3, 1.0),
            # Two diagonally adjacent boxes, enclosing area = 2 * union: IoU=0 and GIoU=-0.5 (loss 1.5)
            (box2, box4, 1.5),
        ]
        for first, second, expected in unreduced_cases:
            assert_iou_loss(loss_fn, first, second, expected, device=device)

        # Test batched loss and reductions
        for reduction, expected in (("sum", 2.5), ("mean", 1.25)):
            assert_iou_loss(loss_fn, box1s, box2s, expected, device=device, reduction=reduction)

        # Test reduction value
        # reduction value other than ["none", "mean", "sum"] should raise a ValueError
        with pytest.raises(ValueError, match="Invalid"):
            loss_fn(box1s, box2s, reduction="xyz")

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_empty_inputs(self, dtype, device):
        loss_fn = ops.generalized_box_iou_loss
        assert_empty_loss(loss_fn, dtype, device)
Yassine Alouini's avatar
Yassine Alouini committed
1505
1506


Aditya Oke's avatar
Aditya Oke committed
1507
1508
class TestCompleteBoxIouLoss:
    """Tests for ``ops.complete_box_iou_loss``."""

    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_ciou_loss(self, dtype, device):
        loss_fn = ops.complete_box_iou_loss
        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

        # Unreduced loss for single box pairs: (first box, second box, expected loss).
        unreduced_cases = [
            (box1, box1, 0.0),
            (box1, box2, 0.8125),
            (box1, box3, 1.1923),
            (box1, box4, 1.2500),
        ]
        for first, second, expected in unreduced_cases:
            assert_iou_loss(loss_fn, first, second, expected, device=device)

        # Batched inputs with each supported reduction.
        for reduction, expected in (("mean", 1.2250), ("sum", 2.4500)):
            assert_iou_loss(loss_fn, box1s, box2s, expected, device=device, reduction=reduction)

        # An unknown reduction must be rejected.
        with pytest.raises(ValueError, match="Invalid"):
            loss_fn(box1s, box2s, reduction="xyz")

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_empty_inputs(self, dtype, device):
        loss_fn = ops.complete_box_iou_loss
        assert_empty_loss(loss_fn, dtype, device)
Yassine Alouini's avatar
Yassine Alouini committed
1527
1528


Aditya Oke's avatar
Aditya Oke committed
1529
class TestDistanceBoxIouLoss:
    """Tests for ``ops.distance_box_iou_loss``."""

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_distance_iou_loss(self, dtype, device):
        loss_fn = ops.distance_box_iou_loss
        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

        # Unreduced loss for single box pairs: (first box, second box, expected loss).
        unreduced_cases = [
            (box1, box1, 0.0),
            (box1, box2, 0.8125),
            (box1, box3, 1.1923),
            (box1, box4, 1.2500),
        ]
        for first, second, expected in unreduced_cases:
            assert_iou_loss(loss_fn, first, second, expected, device=device)

        # Batched inputs with each supported reduction.
        for reduction, expected in (("mean", 1.2250), ("sum", 2.4500)):
            assert_iou_loss(loss_fn, box1s, box2s, expected, device=device, reduction=reduction)

        # An unknown reduction must be rejected.
        with pytest.raises(ValueError, match="Invalid"):
            loss_fn(box1s, box2s, reduction="xyz")

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_empty_distance_iou_inputs(self, dtype, device):
        loss_fn = ops.distance_box_iou_loss
        assert_empty_loss(loss_fn, dtype, device)
Yassine Alouini's avatar
Yassine Alouini committed
1549
1550


Aditya Oke's avatar
Aditya Oke committed
1551
1552
1553
1554
class TestFocalLoss:
    """Tests for ``ops.sigmoid_focal_loss``."""

    def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs):
        """Return ``(inputs, targets)`` covering nine input/target regimes.

        Each combination of an input probability range ("small", "big", "random") with a
        target kind ("zeros", "ones", "random_binary") contributes one tensor of ``shape``.
        The pieces are concatenated, so both results have shape ``(shape[0] * 9, shape[1])``.
        Inputs are converted from probabilities to logits; ``kwargs`` are forwarded to the
        tensor factories.
        """
        value_ranges = {
            "small": (0.0, 0.2),
            "big": (0.8, 1.0),
            "zeros": (0.0, 0.0),
            "ones": (1.0, 1.0),
            "random": (0.0, 1.0),
        }

        def sample(kind):
            if kind == "random_binary":
                return torch.randint(0, 2, shape, **kwargs)
            low, high = value_ranges[kind]
            # NOTE(review): recent PyTorch releases appear to deprecate low == high for floating
            # dtypes in make_tensor ("zeros"/"ones" rely on it) - confirm before upgrading.
            return torch.testing.make_tensor(shape, low=low, high=high, **kwargs)

        def to_logit(probability):
            return torch.log(probability / (1 - probability))

        inputs, targets = [], []
        # product() walks the pairs with the input kind as the slow index. The generation order
        # is kept as is because it determines how the seeded random stream is consumed.
        for input_kind, target_kind in product(("small", "big", "random"), ("zeros", "ones", "random_binary")):
            inputs.append(to_logit(sample(input_kind)))
            targets.append(sample(target_kind))

        return torch.cat(inputs), torch.cat(targets)

    @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
    @pytest.mark.parametrize("gamma", [0, 2])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("seed", [0, 1])
    def test_correct_ratio(self, alpha, gamma, device, dtype, seed):
        if dtype is torch.half and device == "cpu":
            pytest.skip("Currently torch.half is not fully supported on cpu")

        torch.random.manual_seed(seed)
        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
        # The element-wise ratio can only be checked on the unreduced losses.
        focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction="none")
        ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

        assert torch.all(
            focal_loss <= ce_loss
        ), "focal loss must be less or equal to cross entropy loss with same input"

        actual_ratio = (focal_loss / ce_loss).squeeze()

        # Manual computation of the factor focal loss applies on top of cross entropy.
        prob = torch.sigmoid(inputs)
        p_t = prob * targets + (1 - prob) * (1 - targets)
        expected_ratio = (1.0 - p_t) ** gamma
        if alpha >= 0:
            # A negative alpha disables the class-balancing weight.
            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
            expected_ratio = expected_ratio * alpha_t

        tol = 1e-3 if dtype is torch.half else 1e-5
        torch.testing.assert_close(expected_ratio, actual_ratio, atol=tol, rtol=tol)

    @pytest.mark.parametrize("reduction", ["mean", "sum"])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("seed", [2, 3])
    def test_equal_ce_loss(self, reduction, device, dtype, seed):
        if dtype is torch.half and device == "cpu":
            pytest.skip("Currently torch.half is not fully supported on cpu")

        torch.random.manual_seed(seed)
        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)

        # Independent copies so that each loss accumulates its own gradient.
        inputs_fl = inputs.clone().requires_grad_()
        targets_fl = targets.clone()
        inputs_ce = inputs.clone().requires_grad_()
        targets_ce = targets.clone()

        # focal loss should be equal ce_loss if alpha=-1 and gamma=0
        focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=0, alpha=-1, reduction=reduction)
        ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction)
        torch.testing.assert_close(focal_loss, ce_loss)

        # The gradients with respect to the inputs must agree as well.
        focal_loss.backward()
        ce_loss.backward()
        torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad)

    @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
    @pytest.mark.parametrize("gamma", [0, 2])
    @pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("seed", [4, 5])
    def test_jit(self, alpha, gamma, reduction, device, dtype, seed):
        if dtype is torch.half and device == "cpu":
            pytest.skip("Currently torch.half is not fully supported on cpu")

        script_fn = torch.jit.script(ops.sigmoid_focal_loss)
        torch.random.manual_seed(seed)
        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)

        loss_kwargs = dict(gamma=gamma, alpha=alpha, reduction=reduction)
        focal_loss = ops.sigmoid_focal_loss(inputs, targets, **loss_kwargs)
        scripted_focal_loss = script_fn(inputs, targets, **loss_kwargs)

        tol = 1e-3 if dtype is torch.half else 1e-5
        torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol)

    # Raise ValueError for anonymous reduction mode
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_reduction_mode(self, device, dtype, reduction="xyz"):
        if dtype is torch.half and device == "cpu":
            pytest.skip("Currently torch.half is not fully supported on cpu")

        torch.random.manual_seed(0)
        inputs, targets = self._generate_diverse_input_target_pair(device=device, dtype=dtype)
        with pytest.raises(ValueError, match="Invalid"):
            ops.sigmoid_focal_loss(inputs, targets, alpha=0.25, gamma=2, reduction=reduction)

Abhijit Deo's avatar
Abhijit Deo committed
1672

1673
1674
class TestMasksToBoxes:
    """Tests for ``ops.masks_to_boxes``."""

    def test_masks_box(self):
        """Check the boxes computed from the multi-frame ``assets/masks.tiff`` fixture.

        Each frame of the image is one mask. The masks are tested as float16, float32 and
        float64 tensors, and the resulting boxes must always be ``torch.float``.
        """

        def masks_box_check(masks, expected, atol=1e-4):
            out = ops.masks_to_boxes(masks)
            assert out.dtype == torch.float
            torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=True, atol=atol)

        def _create_masks(image, masks):
            # Copy every frame of the image into the matching slot of ``masks``;
            # the assignment casts the frame to the dtype of ``masks``.
            for index in range(image.n_frames):
                image.seek(index)
                frame = np.array(image)
                masks[index] = torch.tensor(frame)

            return masks

        # One expected box (4 values) per frame of the fixture.
        expected = torch.tensor(
            [
                [127, 2, 165, 40],
                [2, 50, 44, 92],
                [56, 63, 98, 100],
                [139, 68, 175, 104],
                [160, 112, 198, 145],
                [49, 138, 99, 182],
                [108, 148, 152, 213],
            ],
            dtype=torch.float,
        )

        assets_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
        mask_path = os.path.join(assets_directory, "masks.tiff")

        # Frames are read through seek(), so the file has to stay open while the masks are
        # built. The context manager closes it afterwards instead of leaking the handle.
        with Image.open(mask_path) as image:
            for dtype in [torch.float16, torch.float32, torch.float64]:
                masks = torch.zeros((image.n_frames, image.height, image.width), dtype=dtype)
                masks = _create_masks(image, masks)
                masks_box_check(masks, expected)


1715
class TestStochasticDepth:
    """Tests for ``ops.StochasticDepth``."""

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("p", [0.2, 0.5, 0.8])
    @pytest.mark.parametrize("mode", ["batch", "row"])
    def test_stochastic_depth_random(self, seed, mode, p):
        torch.manual_seed(seed)
        stats = pytest.importorskip("scipy.stats")
        batch_size = 5
        x = torch.ones(size=(batch_size, 3, 4, 4))
        layer = ops.StochasticDepth(p=p, mode=mode)
        # Smoke-check that the layer has a working repr.
        repr(layer)

        trials = 250
        dropped = 0
        total = 0
        for _ in range(trials):
            out = layer(x)
            # Number of rows of the batch that were kept, i.e. are not entirely zero.
            kept_rows = out.sum(dim=(1, 2, 3)).nonzero().size(0)
            if mode == "batch":
                # One sample per forward pass; it counts as dropped when no row survived.
                dropped += int(kept_rows == 0)
                total += 1
            elif mode == "row":
                # Every row of the batch is its own sample.
                dropped += batch_size - kept_rows
                total += batch_size

        # The observed drop count must be consistent with a drop probability of p.
        p_value = stats.binomtest(dropped, total, p=p).pvalue
        assert p_value > 0.01

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("p", (0, 1))
    @pytest.mark.parametrize("mode", ["batch", "row"])
    def test_stochastic_depth(self, seed, mode, p):
        torch.manual_seed(seed)
        batch_size = 5
        x = torch.ones(size=(batch_size, 3, 4, 4))
        layer = ops.StochasticDepth(p=p, mode=mode)

        out = layer(x)
        if p == 0:
            # Nothing is dropped: the input comes back unchanged.
            assert torch.equal(out, x)
        elif p == 1:
            # Everything is dropped: the output is all zeros.
            assert torch.equal(out, torch.zeros_like(x))

    def make_obj(self, p, mode, wrap=False):
        """Build a ``StochasticDepth`` layer, optionally wrapped in ``StochasticDepthWrapper``."""
        layer = ops.StochasticDepth(p, mode)
        if wrap:
            return StochasticDepthWrapper(layer)
        return layer

    @pytest.mark.parametrize("p", (0, 1))
    @pytest.mark.parametrize("mode", ["batch", "row"])
    def test_is_leaf_node(self, p, mode):
        wrapped = self.make_obj(p, mode, wrap=True)
        node_names = get_graph_node_names(wrapped)

        assert len(node_names) == 2
        first_names, second_names = node_names[0], node_names[1]
        assert len(first_names) == len(second_names)
        # One node per input plus a single node for the op itself.
        assert len(first_names) == 1 + wrapped.n_inputs

1773

1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
class TestUtils:
    """Tests for helpers in ``ops._utils``."""

    @pytest.mark.parametrize("norm_layer", [None, nn.BatchNorm2d, nn.LayerNorm])
    def test_split_normalization_params(self, norm_layer):
        # Pass an explicit list of normalization classes only when one was requested.
        if norm_layer is None:
            norm_classes = None
        else:
            norm_classes = [norm_layer]

        model = models.mobilenet_v3_large(norm_layer=norm_layer)
        params = ops._utils.split_normalization_params(model, norm_classes)

        # Expected sizes of the two returned parameter groups for this model.
        assert len(params[0]) == 92
        assert len(params[1]) == 82


1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
class TestDropBlock:
    """Tests for ``ops.DropBlock2d`` and ``ops.DropBlock3d``."""

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("dim", [2, 3])
    @pytest.mark.parametrize("p", [0, 0.5])
    @pytest.mark.parametrize("block_size", [5, 11])
    @pytest.mark.parametrize("inplace", [True, False])
    def test_drop_block(self, seed, dim, p, block_size, inplace):
        torch.manual_seed(seed)
        batch_size, channels = 5, 3
        # Every spatial dimension has the same extent.
        height = width = depth = 11
        if dim == 2:
            x = torch.ones(size=(batch_size, channels, height, width))
            layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace)
            feature_size = height * width
        elif dim == 3:
            x = torch.ones(size=(batch_size, channels, depth, height, width))
            layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace)
            feature_size = depth * height * width
        # Smoke-check that the layer has a working repr.
        repr(layer)

        out = layer(x)
        if p == 0:
            # Nothing may be dropped.
            assert out.equal(x)
        if block_size == height:
            # A block as large as the feature map drops a channel entirely or not at all.
            for sample_idx, channel_idx in product(range(batch_size), range(channels)):
                assert out[sample_idx, channel_idx].count_nonzero() in (0, feature_size)

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("dim", [2, 3])
    @pytest.mark.parametrize("p", [0.1, 0.2])
    @pytest.mark.parametrize("block_size", [3])
    @pytest.mark.parametrize("inplace", [False])
    def test_drop_block_random(self, seed, dim, p, block_size, inplace):
        torch.manual_seed(seed)
        batch_size, channels = 5, 3
        # Every spatial dimension has the same extent.
        height = width = depth = 11
        if dim == 2:
            x = torch.ones(size=(batch_size, channels, height, width))
            layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace)
        elif dim == 3:
            x = torch.ones(size=(batch_size, channels, depth, height, width))
            layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace)

        trials = 250
        total_cells = 0
        dropped_cells = 0
        # Number of elements of the input, kept as a tensor.
        cell_numel = torch.tensor(x.shape).prod()
        for _ in range(trials):
            with torch.no_grad():
                out = layer(x)
            kept_cells = out.nonzero().size(0)
            dropped_cells += cell_numel - kept_cells
            total_cells += cell_numel

        # The empirical drop rate must be within 15% (relative) of p.
        assert abs(p - dropped_cells / total_cells) / p < 0.15

    def make_obj(self, dim, p, block_size, inplace, wrap=False):
        """Build a 2-D or 3-D DropBlock layer, optionally wrapped in ``DropBlockWrapper``."""
        if dim == 2:
            layer = ops.DropBlock2d(p, block_size, inplace)
        elif dim == 3:
            layer = ops.DropBlock3d(p, block_size, inplace)
        if wrap:
            return DropBlockWrapper(layer)
        return layer

    @pytest.mark.parametrize("dim", (2, 3))
    @pytest.mark.parametrize("p", [0, 1])
    @pytest.mark.parametrize("block_size", [5, 7])
    @pytest.mark.parametrize("inplace", [True, False])
    def test_is_leaf_node(self, dim, p, block_size, inplace):
        wrapped = self.make_obj(dim, p, block_size, inplace, wrap=True)
        node_names = get_graph_node_names(wrapped)

        assert len(node_names) == 2
        first_names, second_names = node_names[0], node_names[1]
        assert len(first_names) == len(second_names)
        # One node per input plus a single node for the op itself.
        assert len(first_names) == 1 + wrapped.n_inputs


1866
# Allow running this test module directly as a script.
if __name__ == "__main__":
    pytest.main(args=[__file__])