test_ops.py 76.2 KB
Newer Older
1
import math
2
import os
3
from abc import ABC, abstractmethod
4
from functools import lru_cache
5
from itertools import product
6
from typing import Callable, List, Tuple
7

8
import numpy as np
9
import pytest
10
import torch
11
import torch.fx
12
import torch.nn.functional as F
13
import torch.testing._internal.optests as optests
14
from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps
15
from PIL import Image
16
from torch import nn, Tensor
17
from torch.autograd import gradcheck
18
from torch.nn.modules.utils import _pair
19
from torchvision import models, ops
20
21
22
from torchvision.models.feature_extraction import get_graph_node_names


23
24
25
26
27
28
29
30
# Names of operator checks used by this test module.
# NOTE(review): presumably these are the test names handed to the
# torch.testing._internal.optests utilities imported above — confirm at the use site.
OPTESTS = [
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
]


31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# Context manager for setting deterministic flag and automatically
# resetting it to its original value
class DeterministicGuard:
    """Temporarily override torch's deterministic-algorithms settings.

    On entry the current values of ``torch.are_deterministic_algorithms_enabled()``
    and ``torch.is_deterministic_algorithms_warn_only_enabled()`` are saved and the
    requested settings are applied; on exit the saved values are re-applied,
    whether or not the body raised.

    Args:
        deterministic: value passed to ``torch.use_deterministic_algorithms``.
        warn_only: keyword-only; forwarded as its ``warn_only`` argument.
    """

    def __init__(self, deterministic, *, warn_only=False):
        self.deterministic = deterministic
        self.warn_only = warn_only

    def __enter__(self):
        # Snapshot the global state first so __exit__ can put it back.
        self.deterministic_restore = torch.are_deterministic_algorithms_enabled()
        self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled()
        torch.use_deterministic_algorithms(self.deterministic, warn_only=self.warn_only)
        # Return the guard so that ``with DeterministicGuard(...) as guard`` binds
        # something useful instead of None.
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        # Returns None (falsy): exceptions raised in the body propagate.
        torch.use_deterministic_algorithms(self.deterministic_restore, warn_only=self.warn_only_restore)


47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
class RoIOpTesterModuleWrapper(nn.Module):
    """Two-input module that calls the wrapped op and discards its result.

    ``n_inputs`` records how many positional inputs ``forward`` takes.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 2
        self.layer = obj

    def forward(self, a, b):
        # The output is intentionally not returned.
        self.layer(a, b)


class MultiScaleRoIAlignModuleWrapper(nn.Module):
    """Three-input module that calls the wrapped op and discards its result.

    ``n_inputs`` records how many positional inputs ``forward`` takes.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 3
        self.layer = obj

    def forward(self, a, b, c):
        # The output is intentionally not returned.
        self.layer(a, b, c)


class DeformConvModuleWrapper(nn.Module):
    """Three-input module that calls the wrapped op and discards its result.

    ``n_inputs`` records how many positional inputs ``forward`` takes.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 3
        self.layer = obj

    def forward(self, a, b, c):
        # The output is intentionally not returned.
        self.layer(a, b, c)


class StochasticDepthWrapper(nn.Module):
    """Single-input module that calls the wrapped op and discards its result.

    ``n_inputs`` records how many positional inputs ``forward`` takes.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 1
        self.layer = obj

    def forward(self, a):
        # The output is intentionally not returned.
        self.layer(a)
85
86


87
88
89
90
91
92
93
94
95
96
class DropBlockWrapper(nn.Module):
    """Single-input module that calls the wrapped op and discards its result.

    ``n_inputs`` records how many positional inputs ``forward`` takes.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 1
        self.layer = obj

    def forward(self, a):
        # The output is intentionally not returned.
        self.layer(a)


97
98
99
100
101
102
103
104
105
class PoolWrapper(nn.Module):
    """Module that forwards ``(imgs, boxes)`` to a pooling module and returns its output.

    ``forward`` is explicitly typed with ``boxes`` as ``List[Tensor]``; in this file
    the wrapper is passed to ``torch.jit.script`` by ``_helper_jit_boxes_list``.
    """

    def __init__(self, pool: nn.Module):
        super().__init__()
        self.pool = pool

    def forward(self, imgs: Tensor, boxes: List[Tensor]) -> Tensor:
        """Return ``self.pool(imgs, boxes)``."""
        return self.pool(imgs, boxes)


106
107
class RoIOpTester(ABC):
    dtype = torch.float64
108
109
    mps_dtype = torch.float32
    mps_backward_atol = 2e-2
110

111
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
112
    @pytest.mark.parametrize("contiguous", (True, False))
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
    @pytest.mark.parametrize(
        "x_dtype",
        (
            torch.float16,
            torch.float32,
            torch.float64,
        ),
        ids=str,
    )
    def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, deterministic=False, **kwargs):
        """Compare ``self.fn`` against ``self.expected_fn`` on a fixed set of RoIs.

        ``rois_dtype`` falls back to ``x_dtype`` when not given. Extra ``kwargs``
        are forwarded unchanged to both ``self.fn`` and ``self.expected_fn``.
        """
        if device == "mps" and x_dtype is torch.float64:
            pytest.skip("MPS does not support float64")

        if rois_dtype is None:
            rois_dtype = x_dtype

        # Half-precision inputs get a looser tolerance (slightly looser still on MPS).
        if x_dtype is torch.half:
            tol = 5e-3 if device == "mps" else 4e-3
        else:
            tol = 1e-5

        pool_size = 5
        # n_channels % (pool_size ** 2) == 0 required for PS operations.
        n_channels = 2 * (pool_size**2)
        x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device)
        if not contiguous:
            # Swapping the two spatial axes yields a non-contiguous view.
            x = x.permute(0, 1, 3, 2)

        roi_rows = [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]]  # format is (xyxy)
        rois = torch.tensor(roi_rows, dtype=rois_dtype, device=device)

        pool_h = pool_w = pool_size
        with DeterministicGuard(deterministic):
            y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs)
        # the following should be true whether we're running an autocast test or not.
        assert y.dtype == x.dtype

        gt_y = self.expected_fn(
            x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=x_dtype, **kwargs
        )
        torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)
157

158
    @pytest.mark.parametrize("device", cpu_and_cuda())
159
160
161
162
163
164
165
166
    def test_is_leaf_node(self, device):
        """Check the node-name lists produced for the wrapped op.

        Expects two lists of equal length, each holding one entry per wrapper
        input plus one more.
        """
        wrapped = self.make_obj(wrap=True).to(device=device)
        node_names = get_graph_node_names(wrapped)

        assert len(node_names) == 2
        first_len = len(node_names[0])
        assert first_len == len(node_names[1])
        assert first_len == 1 + wrapped.n_inputs

167
    @pytest.mark.parametrize("device", cpu_and_cuda())
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
    def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.float):
        """Check that the symbolically traced op matches the eager op on the same inputs."""
        eager_op = self.make_obj().to(device=device)
        traced_op = torch.fx.symbolic_trace(eager_op)

        pool_size = 5
        n_channels = 2 * (pool_size**2)
        x = torch.rand(2, n_channels, 5, 5, dtype=x_dtype, device=device)
        roi_rows = [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]]  # format is (xyxy)
        rois = torch.tensor(roi_rows, dtype=rois_dtype, device=device)

        output_gt = eager_op(x, rois)
        assert output_gt.dtype == x.dtype
        output_fx = traced_op(x, rois)
        assert output_fx.dtype == x.dtype

        tol = 1e-5
        torch.testing.assert_close(output_gt, output_fx, rtol=tol, atol=tol)

186
    @pytest.mark.parametrize("seed", range(10))
187
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
188
    @pytest.mark.parametrize("contiguous", (True, False))
189
    def test_backward(self, seed, device, contiguous, deterministic=False):
        """Run ``gradcheck`` on both the eager op and its scripted counterpart.

        On MPS the class-level ``mps_dtype`` and ``mps_backward_atol`` are used;
        elsewhere ``self.dtype`` and an absolute tolerance of 1e-05.
        """
        on_mps = device == "mps"
        atol = self.mps_backward_atol if on_mps else 1e-05
        dtype = self.mps_dtype if on_mps else self.dtype

        torch.random.manual_seed(seed)
        pool_size = 2
        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=dtype, device=device, requires_grad=True)
        if not contiguous:
            x = x.permute(0, 1, 3, 2)

        roi_rows = [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]]  # format is (xyxy)
        rois = torch.tensor(roi_rows, dtype=dtype, device=device)

        def func(z):
            return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)

        script_func = self.get_script_fn(rois, pool_size)

        # Only the eager check runs under the requested determinism setting.
        with DeterministicGuard(deterministic):
            gradcheck(func, (x,), atol=atol)

        gradcheck(script_func, (x,), atol=atol)

    @needs_mps
    def test_mps_error_inputs(self):
        """Expect a RuntimeError from ``gradcheck`` when the op is fed float16 inputs on MPS."""
        pool_size = 2
        half = torch.float16
        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=half, device="mps", requires_grad=True)
        roi_rows = [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]]  # format is (xyxy)
        rois = torch.tensor(roi_rows, dtype=half, device="mps")

        def func(z):
            return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)

        expected_message = "MPS does not support (?:ps_)?roi_(?:align|pool)? backward with float16 inputs."
        with pytest.raises(RuntimeError, match=expected_message):
            gradcheck(func, (x,))
227

228
    @needs_cuda
229
230
    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
231
232
233
    def test_autocast(self, x_dtype, rois_dtype):
        """Re-run ``test_forward`` on CUDA (non-contiguous input) inside an autocast context."""
        cuda = torch.device("cuda")
        with torch.cuda.amp.autocast():
            self.test_forward(cuda, contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
234
235
236

    def _helper_boxes_shape(self, func):
        """Check that ``func`` rejects boxes with the wrong number of columns.

        Both accepted box formats are exercised against ``func`` itself:
        a single tensor (expected ``Tensor[N, 5]``) and a list of tensors
        (expected ``List[Tensor[N, 4]]``). Each malformed input must raise
        ``AssertionError``.

        Args:
            func: functional op called as ``func(input, boxes, output_size=(2, 2))``.
        """
        # Build the inputs outside of ``pytest.raises`` so that only the op call
        # itself can satisfy the expected AssertionError.
        a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)

        # test boxes as Tensor[N, 5]: four columns is one short.
        boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype)
        with pytest.raises(AssertionError):
            func(a, boxes, output_size=(2, 2))

        # test boxes as List[Tensor[N, 4]]: three columns is one short.
        # This must go through ``func`` (not a hard-coded op) so every subclass
        # actually checks its own operator.
        boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype)
        with pytest.raises(AssertionError):
            func(a, [boxes], output_size=(2, 2))

248
249
250
251
252
253
254
255
    def _helper_jit_boxes_list(self, model):
        x = torch.rand(2, 1, 10, 10)
        roi = torch.tensor([[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], dtype=torch.float).t()
        rois = [roi, roi]
        scriped = torch.jit.script(model)
        y = scriped(x, rois)
        assert y.shape == (10, 1, 3, 3)

256
    @abstractmethod
257
258
    def fn(*args, **kwargs):
        """Run the op under test; subclasses must override.

        Called in this class as
        ``self.fn(x, rois, pool_h, pool_w, spatial_scale=..., sampling_ratio=..., **kwargs)``.
        """
        pass
259

260
261
262
263
    @abstractmethod
    def make_obj(*args, **kwargs):
        """Build the op module under test; subclasses must override.

        Called in this class as ``self.make_obj()`` and ``self.make_obj(wrap=True)``;
        the wrapped result must expose ``n_inputs``.
        """
        pass

264
    @abstractmethod
265
266
    def get_script_fn(*args, **kwargs):
        """Return a one-argument callable for ``gradcheck``; subclasses must override.

        Called in this class as ``self.get_script_fn(rois, pool_size)``.
        """
        pass
267

268
    @abstractmethod
269
270
    def expected_fn(*args, **kwargs):
        """Compute the reference output that ``fn`` is compared to; subclasses must override.

        Called in this class as ``self.expected_fn(x, rois, pool_h, pool_w,
        spatial_scale=..., sampling_ratio=..., device=..., dtype=..., **kwargs)``.
        """
        pass
271

272

273
class TestRoiPool(RoIOpTester):
274
275
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        """Apply ``ops.RoIPool`` to ``(x, rois)``; ``sampling_ratio`` and ``kwargs`` are unused."""
        pool = ops.RoIPool((pool_h, pool_w), spatial_scale)
        return pool(x, rois)
276

277
278
279
280
    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False):
        """Build an ``ops.RoIPool``; with ``wrap=True`` return it inside ``RoIOpTesterModuleWrapper``."""
        pool = ops.RoIPool((pool_h, pool_w), spatial_scale)
        if wrap:
            return RoIOpTesterModuleWrapper(pool)
        return pool

281
    def get_script_fn(self, rois, pool_size):
        """Return a function of ``x`` that calls the scripted ``ops.roi_pool`` with the given RoIs."""
        scripted = torch.jit.script(ops.roi_pool)

        def run(x):
            return scripted(x, rois, pool_size)

        return run
284

285
286
287
    def expected_fn(
        self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64
    ):
        """Reference RoI max-pooling computed with plain Python loops.

        Each RoI row is ``(batch_idx, x1, y1, x2, y2)``; coordinates are scaled by
        ``spatial_scale``, rounded, and treated as inclusive. Every output cell is
        the per-channel max over its bin; empty bins stay zero. ``sampling_ratio``
        is unused. Returns a ``(num_rois, channels, pool_h, pool_w)`` tensor.
        """
        device = torch.device("cpu") if device is None else device

        channels = x.size(1)
        expected = torch.zeros(rois.size(0), channels, pool_h, pool_w, dtype=dtype, device=device)

        def bin_range(index, extent):
            # Bins may overlap: start is floored, stop is ceiled.
            start = int(np.floor(index * extent))
            stop = int(np.ceil((index + 1) * extent))
            return slice(start, stop)

        for roi_idx, roi in enumerate(rois):
            batch_idx = int(roi[0])
            left, top, right, bottom = [int(round(coord.item() * spatial_scale)) for coord in roi[1:]]
            crop = x[batch_idx, :, top : bottom + 1, left : right + 1]

            crop_h, crop_w = crop.shape[-2:]
            bin_h = crop_h / pool_h
            bin_w = crop_w / pool_w

            for i, j in product(range(pool_h), range(pool_w)):
                cell = crop[:, bin_range(i, bin_h), bin_range(j, bin_w)]
                if cell.numel() == 0:
                    continue
                expected[roi_idx, :, i, j] = cell.reshape(channels, -1).max(dim=1)[0]
        return expected
312

313
    def test_boxes_shape(self):
        """Run the shared box-shape checks against ``ops.roi_pool``."""
        op = ops.roi_pool
        self._helper_boxes_shape(op)

316
317
318
319
    def test_jit_boxes_list(self):
        """Run the shared scripted list-of-boxes check on a wrapped 3x3 ``ops.RoIPool``."""
        pool = ops.RoIPool(output_size=[3, 3], spatial_scale=1.0)
        self._helper_jit_boxes_list(PoolWrapper(pool))

320

321
class TestPSRoIPool(RoIOpTester):
322
323
    mps_backward_atol = 5e-2

324
325
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        """Apply ``ops.PSRoIPool`` to ``(x, rois)``.

        ``spatial_scale`` is forwarded to the op (it used to be replaced by a
        hard-coded 1), so this agrees with ``expected_fn`` and ``make_obj``, which
        both honour it. Every caller in this file passes ``spatial_scale=1``, so
        their behaviour is unchanged. ``sampling_ratio`` and ``kwargs`` are unused.
        """
        return ops.PSRoIPool((pool_h, pool_w), spatial_scale)(x, rois)
326

327
328
329
330
    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False):
        """Build an ``ops.PSRoIPool``; with ``wrap=True`` return it inside ``RoIOpTesterModuleWrapper``."""
        pool = ops.PSRoIPool((pool_h, pool_w), spatial_scale)
        if wrap:
            return RoIOpTesterModuleWrapper(pool)
        return pool

331
    def get_script_fn(self, rois, pool_size):
        """Return a function of ``x`` that calls the scripted ``ops.ps_roi_pool`` with the given RoIs."""
        scripted = torch.jit.script(ops.ps_roi_pool)

        def run(x):
            return scripted(x, rois, pool_size)

        return run
334

335
336
337
    def expected_fn(
        self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64
    ):
        """Reference position-sensitive RoI average pooling computed with Python loops.

        Output channel ``c_out`` at cell ``(i, j)`` is the mean of input channel
        ``c_out * pool_h * pool_w + pool_w * i + j`` over that cell's bin. Empty
        bins stay zero. ``sampling_ratio`` is unused. Returns a
        ``(num_rois, channels // (pool_h * pool_w), pool_h, pool_w)`` tensor.
        """
        device = torch.device("cpu") if device is None else device

        in_channels = x.size(1)
        cells = pool_h * pool_w
        assert in_channels % cells == 0, "input channels must be divisible by ph * pw"
        out_channels = int(in_channels / cells)
        expected = torch.zeros(rois.size(0), out_channels, pool_h, pool_w, dtype=dtype, device=device)

        def bin_range(index, extent):
            # Bins may overlap: start is floored, stop is ceiled.
            start = int(np.floor(index * extent))
            stop = int(np.ceil((index + 1) * extent))
            return slice(start, stop)

        for roi_idx, roi in enumerate(rois):
            batch_idx = int(roi[0])
            left, top, right, bottom = [int(round(coord.item() * spatial_scale)) for coord in roi[1:]]
            crop = x[batch_idx, :, top : bottom + 1, left : right + 1]

            # The bin size comes from the coordinate difference (at least 1),
            # not from the size of the inclusive crop.
            bin_h = max(bottom - top, 1) / float(pool_h)
            bin_w = max(right - left, 1) / float(pool_w)

            for i, j in product(range(pool_h), range(pool_w)):
                cell = crop[:, bin_range(i, bin_h), bin_range(j, bin_w)]
                if cell.numel() == 0:
                    continue
                area = cell.size(-2) * cell.size(-1)
                for c_out in range(out_channels):
                    c_in = c_out * cells + pool_w * i + j
                    expected[roi_idx, c_out, i, j] = torch.sum(cell[c_in, :, :]) / area
        return expected
367

368
    def test_boxes_shape(self):
        """Run the shared box-shape checks against ``ops.ps_roi_pool``."""
        op = ops.ps_roi_pool
        self._helper_boxes_shape(op)

371

372
373
def bilinear_interpolate(data, y, x, snap_border=False):
    """Bilinearly interpolate the 2-D array ``data`` at the point ``(y, x)``.

    Neighbouring cells that fall outside ``data`` contribute nothing. With
    ``snap_border=True`` a coordinate in ``(-1, 0]`` is moved to 0 and one in
    ``[size - 1, size)`` is moved to ``size - 1`` before interpolating.

    Returns the weighted sum of the in-bounds neighbours, or the integer 0 when
    no neighbour is in bounds.
    """
    height, width = data.shape

    def snap(coord, size):
        if -1 < coord <= 0:
            return 0
        if size - 1 <= coord < size:
            return size - 1
        return coord

    if snap_border:
        y, x = snap(y, height), snap(x, width)

    row_lo = int(math.floor(y))
    col_lo = int(math.floor(x))
    frac_y = y - row_lo
    frac_x = x - col_lo

    # (weight, index) pairs for the low and high neighbour along each axis.
    col_taps = ((1 - frac_x, col_lo), (frac_x, col_lo + 1))
    row_taps = ((1 - frac_y, row_lo), (frac_y, row_lo + 1))

    total = 0
    for w_col, col in col_taps:
        for w_row, row in row_taps:
            if not (0 <= row < height and 0 <= col < width):
                continue
            total += w_col * w_row * data[row, col]
    return total
402
403


404
class TestRoIAlign(RoIOpTester):
405
406
    mps_backward_atol = 6e-2

AhnDW's avatar
AhnDW committed
407
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs):
        """Apply ``ops.RoIAlign`` to ``(x, rois)``; extra ``kwargs`` are unused."""
        align = ops.RoIAlign(
            (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
        )
        return align(x, rois)
411

412
413
414
415
416
417
    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, aligned=False, wrap=False):
        """Build an ``ops.RoIAlign``; with ``wrap=True`` return it inside ``RoIOpTesterModuleWrapper``."""
        align = ops.RoIAlign(
            (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
        )
        if wrap:
            return RoIOpTesterModuleWrapper(align)
        return align

418
    def get_script_fn(self, rois, pool_size):
        """Return a function of ``x`` that calls the scripted ``ops.roi_align`` with the given RoIs."""
        scripted = torch.jit.script(ops.roi_align)

        def run(x):
            return scripted(x, rois, pool_size)

        return run
421

422
423
424
425
426
427
428
429
430
431
432
433
    def expected_fn(
        self,
        in_data,
        rois,
        pool_h,
        pool_w,
        spatial_scale=1,
        sampling_ratio=-1,
        aligned=False,
        device=None,
        dtype=torch.float64,
    ):
        """Reference RoIAlign computed with Python loops and ``bilinear_interpolate``.

        Each RoI row is ``(batch_idx, x1, y1, x2, y2)``; coordinates are scaled by
        ``spatial_scale`` and shifted by -0.5 when ``aligned`` is true. Every
        output cell is the mean of a ``grid_h x grid_w`` lattice of bilinear
        samples, where each grid size is ``sampling_ratio`` if positive and
        otherwise the ceiling of the bin size. Returns a
        ``(num_rois, channels, pool_h, pool_w)`` tensor.
        """
        device = torch.device("cpu") if device is None else device

        channels = in_data.size(1)
        expected = torch.zeros(rois.size(0), channels, pool_h, pool_w, dtype=dtype, device=device)

        shift = 0.5 if aligned else 0.0

        for roi_idx, roi in enumerate(rois):
            batch_idx = int(roi[0])
            left, top, right, bottom = [coord.item() * spatial_scale - shift for coord in roi[1:]]

            bin_h = (bottom - top) / pool_h
            bin_w = (right - left) / pool_w
            grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h))
            grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))

            for i in range(pool_h):
                start_h = top + i * bin_h
                sample_ys = [start_h + (iy + 0.5) * bin_h / grid_h for iy in range(grid_h)]
                for j in range(pool_w):
                    start_w = left + j * bin_w
                    sample_xs = [start_w + (ix + 0.5) * bin_w / grid_w for ix in range(grid_w)]

                    for channel in range(channels):
                        plane = in_data[batch_idx, channel, :, :]
                        acc = 0
                        for sample_y in sample_ys:
                            for sample_x in sample_xs:
                                acc += bilinear_interpolate(plane, sample_y, sample_x, snap_border=True)
                        acc /= grid_h * grid_w

                        expected[roi_idx, channel, i, j] = acc
        return expected

469
    def test_boxes_shape(self):
        """Run the shared box-shape checks against ``ops.roi_align``."""
        op = ops.roi_align
        self._helper_boxes_shape(op)

472
    @pytest.mark.parametrize("aligned", (True, False))
473
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
474
    @pytest.mark.parametrize("x_dtype", (torch.float16, torch.float32, torch.float64))  # , ids=str)
475
    @pytest.mark.parametrize("contiguous", (True, False))
476
    @pytest.mark.parametrize("deterministic", (True, False))
477
    def test_forward(self, device, contiguous, deterministic, aligned, x_dtype, rois_dtype=None):
        """Run the base forward check, additionally forwarding ``deterministic`` and ``aligned``."""
        if deterministic and device == "cpu":
            pytest.skip("cpu is always deterministic, don't retest")

        forwarded = dict(
            device=device,
            contiguous=contiguous,
            deterministic=deterministic,
            x_dtype=x_dtype,
            rois_dtype=rois_dtype,
            aligned=aligned,
        )
        super().test_forward(**forwarded)
488

489
    @needs_cuda
490
    @pytest.mark.parametrize("aligned", (True, False))
491
    @pytest.mark.parametrize("deterministic", (True, False))
492
493
    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
494
    def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
        """Re-run ``test_forward`` on CUDA (non-contiguous input) inside an autocast context."""
        cuda = torch.device("cuda")
        forwarded = dict(
            contiguous=False,
            deterministic=deterministic,
            aligned=aligned,
            x_dtype=x_dtype,
            rois_dtype=rois_dtype,
        )
        with torch.cuda.amp.autocast():
            self.test_forward(cuda, **forwarded)
504

505
    @pytest.mark.parametrize("seed", range(10))
506
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
507
508
509
510
511
512
513
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("deterministic", (True, False))
    def test_backward(self, seed, device, contiguous, deterministic):
        """Run the base backward check, skipping the deterministic variant on CPU."""
        redundant = deterministic and device == "cpu"
        if redundant:
            pytest.skip("cpu is always deterministic, don't retest")
        super().test_backward(seed, device, contiguous, deterministic)

514
515
516
517
518
519
    def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
        rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype)
        rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,))  # set batch index
        rois[:, 3:] += rois[:, 1:3]  # make sure boxes aren't degenerate
        return rois

520
521
522
    @pytest.mark.parametrize("aligned", (True, False))
    @pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 10), (0.1, 50)))
    @pytest.mark.parametrize("qdtype", (torch.qint8, torch.quint8, torch.qint32))
523
    def test_qroialign(self, aligned, scale, zero_point, qdtype):
        """Make sure quantized version of RoIAlign is close to float version"""
        pool_size = 5
        img_size = 10
        n_channels = 2
        num_imgs = 1
        dtype = torch.float

        def quantize(tensor):
            return torch.quantize_per_tensor(tensor, scale=scale, zero_point=zero_point, dtype=qdtype)

        x = torch.randint(50, 100, size=(num_imgs, n_channels, img_size, img_size)).to(dtype)
        qx = quantize(x)

        rois = self._make_rois(img_size, num_imgs, dtype)
        qrois = quantize(rois)

        # we want to pass the same inputs to both versions, so the float path
        # uses the dequantized values
        x, rois = qx.dequantize(), qrois.dequantize()

        align_kwargs = dict(output_size=pool_size, spatial_scale=1, sampling_ratio=-1, aligned=aligned)
        y = ops.roi_align(x, rois, **align_kwargs)
        qy = ops.roi_align(qx, qrois, **align_kwargs)

        # The output qy is itself a quantized tensor and there might have been a loss of info when it was
        # quantized. For a fair comparison we need to quantize y as well
        quantized_float_y = quantize(y)

        # Ideally both outputs match exactly, which is the case with (scale, zero) == (1, 0)
        if (qy == quantized_float_y).all():
            return

        # But because the computations aren't exactly the same between the 2 RoIAlign procedures, some
        # rounding error may lead to a difference of 2 in the output.
        # For example with (scale, zero) = (2, 10), 45.00000... will be quantized to 44
        # but 45.00000001 will be rounded to 46. We make sure below that:
        # - such discrepancies between qy and quantized_float_y are very rare (less than 5%)
        # - any difference between qy and quantized_float_y is == scale
        diff_idx = torch.where(qy != quantized_float_y)
        num_diff = diff_idx[0].numel()
        assert num_diff / qy.numel() < 0.05

        abs_diff = torch.abs(qy[diff_idx].dequantize() - quantized_float_y[diff_idx].dequantize())
        t_scale = torch.full_like(abs_diff, fill_value=scale)
        torch.testing.assert_close(abs_diff, t_scale, rtol=1e-5, atol=1e-5)

    def test_qroi_align_multiple_images(self):
        """Quantized ``roi_align`` must raise a RuntimeError for a two-image batch."""
        dtype = torch.float
        quant_kwargs = dict(scale=1, zero_point=0, dtype=torch.qint8)

        x = torch.randint(50, 100, size=(2, 3, 10, 10)).to(dtype)
        qx = torch.quantize_per_tensor(x, **quant_kwargs)
        rois = self._make_rois(img_size=10, num_imgs=2, dtype=dtype, num_rois=10)
        qrois = torch.quantize_per_tensor(rois, **quant_kwargs)

        with pytest.raises(RuntimeError, match="Only one image per batch is allowed"):
            ops.roi_align(qx, qrois, output_size=5)
586

587
588
589
590
    def test_jit_boxes_list(self):
        """Run the shared scripted list-of-boxes check on a wrapped 3x3 ``ops.RoIAlign``."""
        align = ops.RoIAlign(output_size=[3, 3], spatial_scale=1.0, sampling_ratio=-1)
        self._helper_jit_boxes_list(PoolWrapper(align))

591

592
class TestPSRoIAlign(RoIOpTester):
593
594
    mps_backward_atol = 5e-2

595
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        """Apply ``ops.PSRoIAlign`` to ``(x, rois)``; extra ``kwargs`` are unused."""
        align = ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)
        return align(x, rois)
597

598
599
600
601
    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, wrap=False):
        """Build an ``ops.PSRoIAlign``; with ``wrap=True`` return it inside ``RoIOpTesterModuleWrapper``."""
        align = ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)
        if wrap:
            return RoIOpTesterModuleWrapper(align)
        return align

602
    def get_script_fn(self, rois, pool_size):
        """Return a function of ``x`` that calls the scripted ``ops.ps_roi_align`` with the given RoIs."""
        scripted = torch.jit.script(ops.ps_roi_align)

        def run(x):
            return scripted(x, rois, pool_size)

        return run
605

606
607
608
    def expected_fn(
        self, in_data, rois, pool_h, pool_w, device=None, spatial_scale=1, sampling_ratio=-1, dtype=torch.float64
    ):
        """Reference position-sensitive RoIAlign computed with Python loops.

        Args:
            in_data: input feature tensor, indexed as ``[batch, channel, :, :]``.
            rois: rows of ``(batch_idx, x1, y1, x2, y2)``.
            pool_h, pool_w: output grid size; the channel count of ``in_data`` must
                be divisible by ``pool_h * pool_w``.
            device: device of the result. Defaults to CPU when ``None``; the
                parameter keeps its position but is now optional, matching the
                other ``expected_fn`` implementations in this file.
            spatial_scale: factor applied to the RoI coordinates.
            sampling_ratio: samples per bin along each axis; when not positive the
                ceiling of the bin size is used instead.
            dtype: dtype of the result.

        Returns:
            Tensor of shape ``(num_rois, channels // (pool_h * pool_w), pool_h, pool_w)``.
            Output channel ``c_out`` at cell ``(i, j)`` is the mean of bilinear
            samples taken from input channel
            ``c_out * pool_h * pool_w + pool_w * i + j``.

        Raises:
            AssertionError: if the channel count is not divisible by ``pool_h * pool_w``.
        """
        if device is None:
            device = torch.device("cpu")
        n_input_channels = in_data.size(1)
        assert n_input_channels % (pool_h * pool_w) == 0, "input channels must be divisible by ph * pw"
        n_output_channels = int(n_input_channels / (pool_h * pool_w))
        out_data = torch.zeros(rois.size(0), n_output_channels, pool_h, pool_w, dtype=dtype, device=device)

        for r, roi in enumerate(rois):
            batch_idx = int(roi[0])
            # Coordinates are always shifted by half a pixel here.
            j_begin, i_begin, j_end, i_end = (coord.item() * spatial_scale - 0.5 for coord in roi[1:])

            roi_h = i_end - i_begin
            roi_w = j_end - j_begin
            bin_h = roi_h / pool_h
            bin_w = roi_w / pool_w

            for i in range(0, pool_h):
                start_h = i_begin + i * bin_h
                grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h))
                for j in range(0, pool_w):
                    start_w = j_begin + j * bin_w
                    grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))
                    for c_out in range(0, n_output_channels):
                        # Each output cell reads from its own dedicated input channel.
                        c_in = c_out * (pool_h * pool_w) + pool_w * i + j

                        val = 0
                        for iy in range(0, grid_h):
                            y = start_h + (iy + 0.5) * bin_h / grid_h
                            for ix in range(0, grid_w):
                                x = start_w + (ix + 0.5) * bin_w / grid_w
                                val += bilinear_interpolate(in_data[batch_idx, c_in, :, :], y, x, snap_border=True)
                        val /= grid_h * grid_w

                        out_data[r, c_out, i, j] = val
        return out_data
644

645
    def test_boxes_shape(self):
        """Run the shared box-shape checks against ``ops.ps_roi_align``."""
        op = ops.ps_roi_align
        self._helper_boxes_shape(op)

648

649
class TestMultiScaleRoIAlign:
650
651
652
653
654
655
    def make_obj(self, fmap_names=None, output_size=(7, 7), sampling_ratio=2, wrap=False):
        """Build a ``MultiScaleRoIAlign`` pooler (feature map names default to ``["0"]``).

        With ``wrap=True`` the pooler is returned inside ``MultiScaleRoIAlignModuleWrapper``.
        """
        names = ["0"] if fmap_names is None else fmap_names
        pooler = ops.poolers.MultiScaleRoIAlign(names, output_size, sampling_ratio)
        if wrap:
            return MultiScaleRoIAlignModuleWrapper(pooler)
        return pooler

656
    def test_msroialign_repr(self):
        """Check that ``repr`` of the pooler lists its constructor arguments."""
        # Pass mock feature map names
        fmap_names = ["0"]
        output_size = (7, 7)
        sampling_ratio = 2
        pooler = self.make_obj(fmap_names, output_size, sampling_ratio, wrap=False)

        # Check integrity of object __repr__ attribute
        expected_string = (
            f"MultiScaleRoIAlign(featmap_names={fmap_names}, output_size={output_size}, "
            f"sampling_ratio={sampling_ratio})"
        )
        actual_string = repr(pooler)
        assert actual_string == expected_string
669

670
    @pytest.mark.parametrize("device", cpu_and_cuda())
671
672
673
674
675
676
677
678
    def test_is_leaf_node(self, device):
        """Check the node-name lists produced for the wrapped pooler.

        Expects two lists of equal length, each holding one entry per wrapper
        input plus one more.
        """
        wrapped = self.make_obj(wrap=True).to(device=device)
        node_names = get_graph_node_names(wrapped)

        assert len(node_names) == 2
        first_len = len(node_names[0])
        assert first_len == len(node_names[1])
        assert first_len == 1 + wrapped.n_inputs

679

680
681
class TestNMS:
    def _reference_nms(self, boxes, scores, iou_threshold):
        """Greedy pure-PyTorch NMS used as the reference for ``ops.nms``.

        Args:
            boxes: boxes in corner-form
            scores: probabilities
            iou_threshold: intersection over union threshold

        Returns:
            A tensor with the indexes of the kept boxes, in the order they were picked.
        """
        kept = []
        order = scores.sort(descending=True)[1]
        while len(order) > 0:
            best = order[0]
            kept.append(best.item())
            if len(order) == 1:
                break
            # Drop the box just picked, then discard every remaining box whose
            # IoU with it is above the threshold.
            order = order[1:]
            overlaps = ops.box_iou(boxes[order, :], boxes[best, :].unsqueeze(0)).squeeze(1)
            order = order[overlaps <= iou_threshold]

        return torch.as_tensor(kept)

705
706
707
708
709
    def _create_tensors_with_iou(self, N, iou_thresh):
        """Create ``N`` random boxes and scores where the last box overlaps the first by a chosen IoU.

        The last box is a copy of the first one, widened along x. With
        b0 = [x0, y0, x1, y1] and b1 = [x0, y0, x1 + d, y1], having
        iou(b0, b1) == iou_thresh requires d = (x1 - x0) * (1 - iou_thresh) / iou_thresh.

        Args:
            N: number of boxes (and scores) to generate.
            iou_thresh: IoU the last box should have with the first, before the small bump below.

        Returns:
            ``(boxes, scores)``: an ``(N, 4)`` tensor and an ``(N,)`` tensor.
        """
        boxes = torch.rand(N, 4) * 100
        # Turn the last two columns from random extents into far corners.
        boxes[:, 2:] += boxes[:, :2]
        boxes[-1, :] = boxes[0, :]
        x0, _, x1, _ = boxes[-1].tolist()

        # Adjust the threshold upward a bit with the intent of creating
        # at least one box that exceeds (barely) the threshold and so
        # should be suppressed.
        target_iou = iou_thresh + 1e-5
        widening = (x1 - x0) * (1 - target_iou) / target_iou
        boxes[-1, 2] += widening

        scores = torch.rand(N)
        return boxes, scores

722
    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.opcheck_only_one()
    def test_nms_ref(self, iou, seed):
        """``ops.nms`` must keep the same boxes as ``_reference_nms`` on seeded random input."""
        torch.random.manual_seed(seed)
        boxes, scores = self._create_tensors_with_iou(1000, iou)

        expected_keep = self._reference_nms(boxes, scores, iou)
        actual_keep = ops.nms(boxes, scores, iou)

        failure_msg = "NMS incompatible between CPU and reference implementation for IoU={}".format(iou)
        torch.testing.assert_close(actual_keep, expected_keep, msg=failure_msg)
732
733
734
735
736
737
738
739
740
741
742

    def test_nms_input_errors(self):
        """``ops.nms`` must raise ``RuntimeError`` for each malformed boxes/scores shape pair."""
        bad_shape_pairs = (
            ((4,), (3,)),  # boxes are 1-D
            ((3, 5), (3,)),  # boxes have 5 columns
            ((3, 4), (3, 2)),  # scores are 2-D
            ((3, 4), (4,)),  # 3 boxes but 4 scores
        )
        for boxes_shape, scores_shape in bad_shape_pairs:
            with pytest.raises(RuntimeError):
                ops.nms(torch.rand(*boxes_shape), torch.rand(*scores_shape), 0.5)

743
    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 50), (3, 10)))
    @pytest.mark.opcheck_only_one()
    def test_qnms(self, iou, scale, zero_point):
        """NMS on quantized tensors must match NMS on their dequantized counterparts."""
        # Note: we compare qnms vs nms instead of qnms vs reference implementation.
        # This is because with the int conversion, the trick used in _create_tensors_with_iou
        # doesn't really work (in fact, nms vs reference implem will also fail with ints)
        boxes, scores = self._create_tensors_with_iou(1000, iou)
        scores *= 100  # otherwise most scores would be 0 or 1 after int conversion

        quant_params = dict(scale=scale, zero_point=zero_point, dtype=torch.quint8)
        qboxes = torch.quantize_per_tensor(boxes, **quant_params)
        qscores = torch.quantize_per_tensor(scores, **quant_params)

        # The float run uses the dequantized values so both runs see the same numbers.
        keep = ops.nms(qboxes.dequantize(), qscores.dequantize(), iou)
        qkeep = ops.nms(qboxes, qscores, iou)

        torch.testing.assert_close(qkeep, keep, msg="NMS and QNMS give different results for IoU={}".format(iou))
764

765
766
767
768
769
770
771
    @pytest.mark.parametrize(
        "device",
        (
            pytest.param("cuda", marks=pytest.mark.needs_cuda),
            pytest.param("mps", marks=pytest.mark.needs_mps),
        ),
    )
    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.opcheck_only_one()
    def test_nms_gpu(self, iou, device, dtype=torch.float64):
        """Compare ``ops.nms`` on CPU with ``ops.nms`` on an accelerator device.

        Args:
            iou: IoU threshold passed to ``ops.nms``.
            device: accelerator to compare against the CPU result ("cuda" or "mps").
            dtype: only used to pick the comparison tolerance; replaced by
                ``torch.float32`` when ``device`` is "mps".

        NOTE(review): the generated boxes and scores are never cast to ``dtype`` —
        confirm whether that is intended.
        """
        dtype = torch.float32 if device == "mps" else dtype
        tol = 1e-3 if dtype is torch.half else 1e-5
        # Name the device actually under test so MPS failures are not reported as CUDA.
        err_msg = "NMS incompatible between CPU and {} for IoU={}".format(str(device).upper(), iou)

        boxes, scores = self._create_tensors_with_iou(1000, iou)
        r_cpu = ops.nms(boxes, scores, iou)
        r_gpu = ops.nms(boxes.to(device), scores.to(device), iou).cpu()

        # torch.equal returns False on a length mismatch, whereas torch.allclose
        # would raise a broadcasting error and hide the assertion message.
        is_eq = torch.equal(r_cpu, r_gpu)
        if not is_eq and r_cpu.shape == r_gpu.shape:
            # if the indices are not the same, ensure that it's because the scores
            # are duplicate
            is_eq = torch.allclose(scores[r_cpu], scores[r_gpu], rtol=tol, atol=tol)
        assert is_eq, err_msg

    @needs_cuda
    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.parametrize("dtype", (torch.float, torch.half))
    @pytest.mark.opcheck_only_one()
    def test_autocast(self, iou, dtype):
        """Re-run ``test_nms_gpu`` on CUDA inside a ``torch.cuda.amp.autocast`` context."""
        autocast_ctx = torch.cuda.amp.autocast()
        with autocast_ctx:
            self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")

    @pytest.mark.parametrize(
        "device",
        (
            pytest.param("cuda", marks=pytest.mark.needs_cuda),
            pytest.param("mps", marks=pytest.mark.needs_mps),
        ),
    )
    @pytest.mark.opcheck_only_one()
    def test_nms_float16(self, device):
        """float16 inputs must keep the same boxes as the default-dtype inputs on ``device``."""
        coords = [
            [285.3538, 185.5758, 1193.5110, 851.4551],
            [285.1472, 188.7374, 1192.4984, 851.0669],
            [279.2440, 197.9812, 1189.4746, 849.2019],
        ]
        boxes = torch.tensor(coords).to(device)
        scores = torch.tensor([0.6370, 0.7569, 0.3966]).to(device)

        iou_thres = 0.2
        keep32 = ops.nms(boxes, scores, iou_thres)
        keep16 = ops.nms(boxes.half(), scores.half(), iou_thres)
        assert_equal(keep32, keep16)
820

821
    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.opcheck_only_one()
    def test_batched_nms_implementations(self, seed):
        """Make sure that both implementations of batched_nms yield identical results"""
        torch.random.manual_seed(seed)

        num_boxes = 1000
        iou_threshold = 0.9

        # The second corner is shifted by 10, so it always lies beyond the first.
        top_left = torch.rand(num_boxes, 2)
        bottom_right = torch.rand(num_boxes, 2) + 10
        boxes = torch.cat((top_left, bottom_right), dim=1)
        assert boxes[:, 0].max() < boxes[:, 2].min()  # x1 < x2
        assert boxes[:, 1].max() < boxes[:, 3].min()  # y1 < y2

        scores = torch.rand(num_boxes)
        idxs = torch.randint(0, 4, size=(num_boxes,))
        nms_args = (boxes, scores, idxs, iou_threshold)
        keep_vanilla = ops.boxes._batched_nms_vanilla(*nms_args)
        keep_trick = ops.boxes._batched_nms_coordinate_trick(*nms_args)

        torch.testing.assert_close(
            keep_vanilla, keep_trick, msg="The vanilla and the trick implementation yield different nms outputs."
        )

        # Also make sure an empty tensor is returned if boxes is empty
        empty = torch.empty((0,), dtype=torch.int64)
        torch.testing.assert_close(empty, ops.batched_nms(empty, None, None, None))
846

847

848
849
850
851
852
853
854
855
856
# Ask torch's optests helper to generate opcheck tests from TestNMS for the
# "torchvision" namespace, running the checks named in OPTESTS. Known failures
# are read from optests_failures_dict.json located next to this file.
optests.generate_opcheck_tests(
    testcase=TestNMS,
    namespaces=["torchvision"],
    failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
    additional_decorators=[],
    test_utils=OPTESTS,
)


857
858
859
class TestDeformConv:
    dtype = torch.float64

860
    def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
        """Naive reference implementation of (optionally masked) deformable convolution.

        Args:
            x: 4-D input; its shape is unpacked as (batch, in_channels, height, width).
            weight: kernel; dim 0 is the output channels, dim 1 the input channels
                per weight group, and the last two dims the kernel height and width.
            offset: sampling offsets; dim 1 holds ``2 * kernel_h * kernel_w`` entries
                per offset group, one for the row and one for the column of each tap.
            mask: per-tap multiplier indexed like ``offset`` but with one entry per
                tap, or ``None`` to use 1.0 everywhere.
            bias: one value per output channel, added at the end.
            stride, padding, dilation: ints or pairs, expanded with ``_pair``.

        Returns:
            Tensor of shape (batch, out_channels, out_h, out_w) with the dtype and
            device of ``x``.
        """
        stride_h, stride_w = _pair(stride)
        pad_h, pad_w = _pair(padding)
        dil_h, dil_w = _pair(dilation)
        kernel_h, kernel_w = weight.shape[-2:]

        n_batches, n_in_channels, in_h, in_w = x.shape
        n_out_channels = weight.shape[0]

        out_h = (in_h + 2 * pad_h - (dil_h * (kernel_h - 1) + 1)) // stride_h + 1
        out_w = (in_w + 2 * pad_w - (dil_w * (kernel_w - 1) + 1)) // stride_w + 1

        taps = kernel_h * kernel_w
        n_offset_grps = offset.shape[1] // (2 * taps)
        in_c_per_offset_grp = n_in_channels // n_offset_grps

        in_c_per_weight_grp = weight.shape[1]
        n_weight_grps = n_in_channels // in_c_per_weight_grp
        out_c_per_weight_grp = n_out_channels // n_weight_grps

        out = torch.zeros(n_batches, n_out_channels, out_h, out_w, device=x.device, dtype=x.dtype)
        # product() walks the indices in the same order as the equivalent nested loops.
        for b, c_out, i, j in product(range(n_batches), range(n_out_channels), range(out_h), range(out_w)):
            weight_grp = c_out // out_c_per_weight_grp
            for di, dj, c in product(range(kernel_h), range(kernel_w), range(in_c_per_weight_grp)):
                c_in = weight_grp * in_c_per_weight_grp + c

                offset_grp = c_in // in_c_per_offset_grp
                mask_idx = offset_grp * taps + di * kernel_w + dj
                offset_idx = 2 * mask_idx

                # Regular sampling position shifted by the learned offsets.
                pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j]
                pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j]

                mask_value = 1.0 if mask is None else mask[b, mask_idx, i, j]

                out[b, c_out, i, j] += (
                    mask_value * weight[c_out, c, di, dj] * bilinear_interpolate(x[b, c_in, :, :], pi, pj)
                )
        out += bias.view(1, n_out_channels, 1, 1)
        return out

909
    @lru_cache(maxsize=None)
    def get_fn_args(self, device, contiguous, batch_sz, dtype):
        """Build random inputs for the deformable-convolution tests.

        NOTE(review): ``lru_cache`` on an instance method keys on ``self`` and keeps
        every instance and its cached tensors alive — confirm the cache is still wanted.

        Args:
            device: device on which every tensor is created.
            contiguous: when false, the input, offset, mask and weight are returned
                as non-contiguous views built by permuting, copying and permuting back.
            batch_sz: size of the leading dimension of the input, offset and mask.
            dtype: dtype of every tensor.

        Returns:
            ``(x, weight, offset, mask, bias, stride, pad, dilation)``.
        """
        n_in_channels, n_out_channels = 6, 2
        n_weight_grps, n_offset_grps = 2, 3

        stride, pad, dilation = (2, 1), (1, 0), (2, 1)
        (stride_h, stride_w), (pad_h, pad_w), (dil_h, dil_w) = stride, pad, dilation
        weight_h, weight_w = 3, 2
        in_h, in_w = 5, 4

        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1

        # Every tensor shares the same placement and takes part in autograd.
        tensor_opts = dict(device=device, dtype=dtype, requires_grad=True)
        taps = weight_h * weight_w

        x = torch.rand(batch_sz, n_in_channels, in_h, in_w, **tensor_opts)
        offset = torch.randn(batch_sz, n_offset_grps * 2 * taps, out_h, out_w, **tensor_opts)
        mask = torch.randn(batch_sz, n_offset_grps * taps, out_h, out_w, **tensor_opts)
        weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w, **tensor_opts)
        bias = torch.randn(n_out_channels, **tensor_opts)

        if not contiguous:
            x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
            offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)

        return x, weight, offset, mask, bias, stride, pad, dilation
964

965
966
967
968
969
970
    def make_obj(self, in_channels=6, out_channels=2, kernel_size=(3, 2), groups=2, wrap=False):
        """Build an ``ops.DeformConv2d`` with stride (2, 1), padding (1, 0) and dilation (2, 1).

        Returns the layer itself, or the layer inside ``DeformConvModuleWrapper``
        when ``wrap`` is true.
        """
        conv = ops.DeformConv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=(2, 1),
            padding=(1, 0),
            dilation=(2, 1),
            groups=groups,
        )
        if not wrap:
            return conv
        return DeformConvModuleWrapper(conv)

971
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_is_leaf_node(self, device):
        """The wrapped layer must yield exactly one graph node beyond its inputs."""
        wrapped = self.make_obj(wrap=True).to(device=device)
        node_names = get_graph_node_names(wrapped)

        # Two name lists are expected; both must have the same length.
        assert len(node_names) == 2
        first, second = node_names
        assert len(first) == len(second)
        assert len(first) == 1 + wrapped.n_inputs

980
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("batch_sz", (0, 33))
    def test_forward(self, device, contiguous, batch_sz, dtype=None):
        """Compare ``DeformConv2d`` forward output with ``expected_fn``, with and without a mask."""
        dtype = dtype or self.dtype
        x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
        in_channels, out_channels = 6, 2
        kernel_size, groups = (3, 2), 2
        tol = 2e-3 if dtype is torch.half else 1e-5

        layer = self.make_obj(in_channels, out_channels, kernel_size, groups, wrap=False)
        layer = layer.to(device=x.device, dtype=dtype)
        weight, bias = layer.weight.data, layer.bias.data

        # First the modulated call (with mask), then the "no modulation" call.
        for layer_inputs, ref_mask in (((x, offset, mask), mask), ((x, offset), None)):
            res = layer(*layer_inputs)
            expected = self.expected_fn(
                x, weight, offset, ref_mask, bias, stride=stride, padding=padding, dilation=dilation
            )
            torch.testing.assert_close(
                res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}"
            )
1012

1013
1014
1015
1016
1017
    def test_wrong_sizes(self):
        """``DeformConv2d`` must reject an offset or mask with too few channels."""
        x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(
            "cpu", contiguous=True, batch_sz=10, dtype=self.dtype
        )
        conv_cfg = dict(stride=stride, padding=padding, dilation=dilation, groups=2)
        layer = ops.DeformConv2d(6, 2, (3, 2), **conv_cfg)

        with pytest.raises(RuntimeError, match="the shape of the offset"):
            # Only two offset channels instead of the full set.
            layer(x, torch.rand_like(offset[:, :2]))

        with pytest.raises(RuntimeError, match=r"mask.shape\[1\] is not valid"):
            # Only two mask channels instead of the full set.
            layer(x, offset, torch.rand_like(mask[:, :2]))
1031

1032
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("batch_sz", (0, 33))
    def test_backward(self, device, contiguous, batch_sz):
        """Run ``gradcheck`` on ``ops.deform_conv2d``, eager and scripted, with and without a mask."""
        x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(
            device, contiguous, batch_sz, self.dtype
        )
        conv_cfg = dict(stride=stride, padding=padding, dilation=dilation)

        def run_gradcheck(fn, inputs):
            # Same gradcheck settings for all four variants.
            gradcheck(fn, inputs, nondet_tol=1e-5, fast_mode=True)

        def func(x_, offset_, mask_, weight_, bias_):
            return ops.deform_conv2d(x_, offset_, weight_, bias_, mask=mask_, **conv_cfg)

        def func_no_mask(x_, offset_, weight_, bias_):
            return ops.deform_conv2d(x_, offset_, weight_, bias_, mask=None, **conv_cfg)

        run_gradcheck(func, (x, offset, mask, weight, bias))
        run_gradcheck(func_no_mask, (x, offset, weight, bias))

        @torch.jit.script
        def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_):
            # type:(Tensor, Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
            return ops.deform_conv2d(
                x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_, mask=mask_
            )

        def scripted(z, off, msk, wei, bi):
            # The scripted function takes the conv settings as explicit arguments.
            return script_func(z, off, msk, wei, bi, stride, padding, dilation)

        run_gradcheck(scripted, (x, offset, mask, weight, bias))

        @torch.jit.script
        def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
            # type:(Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
            return ops.deform_conv2d(
                x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_, mask=None
            )

        def scripted_no_mask(z, off, wei, bi):
            return script_func_no_mask(z, off, wei, bi, stride, padding, dilation)

        run_gradcheck(scripted_no_mask, (x, offset, weight, bias))
1081

1082
    @needs_cuda
    @pytest.mark.parametrize("contiguous", (True, False))
    def test_compare_cpu_cuda_grads(self, contiguous):
        """Weight gradients of ``ops.deform_conv2d`` must agree between CPU and CUDA.

        Args:
            contiguous: when false, the image, offset, mask and weight are passed
                as non-contiguous views built by permuting, copying and permuting back.
        """
        # Test from https://github.com/pytorch/vision/issues/2598
        # Run on CUDA only

        init_weight = torch.randn(9, 9, 3, 3, requires_grad=True)
        img = torch.randn(8, 9, 1000, 110)
        offset = torch.rand(8, 2 * 3 * 3, 1000, 110)
        mask = torch.rand(8, 3 * 3, 1000, 110)

        if not contiguous:
            img = img.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
            offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            weight = init_weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
        else:
            weight = init_weight

        # compare grads computed on CUDA with grads computed on CPU
        grads = {}
        for d in ["cpu", "cuda"]:
            # Start each pass from a clean gradient. Otherwise the second backward
            # accumulates into the tensor recorded for the first pass and the
            # comparison ends up checking a tensor against itself.
            init_weight.grad = None
            out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d))
            out.mean().backward()
            assert init_weight.grad is not None
            # Snapshot the gradient so later passes cannot modify it.
            grads[d] = init_weight.grad.detach().clone().to("cpu")

        torch.testing.assert_close(grads["cpu"], grads["cuda"])

    @needs_cuda
    @pytest.mark.parametrize("batch_sz", (0, 33))
    @pytest.mark.parametrize("dtype", (torch.float, torch.half))
    def test_autocast(self, batch_sz, dtype):
        """Re-run ``test_forward`` on CUDA with non-contiguous inputs inside ``torch.cuda.amp.autocast``."""
        cuda_device = torch.device("cuda")
        with torch.cuda.amp.autocast():
            self.test_forward(cuda_device, contiguous=False, batch_sz=batch_sz, dtype=dtype)

1122
1123
1124
1125
    def test_forward_scriptability(self):
        """``DeformConv2d`` must be accepted by ``torch.jit.script``."""
        # Non-regression test for https://github.com/pytorch/vision/issues/4078
        layer = ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3)
        torch.jit.script(layer)

1126
1127

class TestFrozenBNT:
    """Tests for ``ops.misc.FrozenBatchNorm2d``: its repr and its agreement with ``BatchNorm2d``."""

    def test_frozenbatchnorm2d_repr(self):
        """The repr must show the number of features and the eps value."""
        num_features, eps = 32, 1e-5
        t = ops.misc.FrozenBatchNorm2d(num_features, eps=eps)

        # Check integrity of object __repr__ attribute
        expected_string = f"FrozenBatchNorm2d({num_features}, eps={eps})"
        assert repr(t) == expected_string

    @pytest.mark.parametrize("seed", range(10))
    def test_frozenbatchnorm2d_eps(self, seed):
        """``FrozenBatchNorm2d`` must match an eval-mode ``BatchNorm2d`` loaded with the same state."""
        torch.random.manual_seed(seed)
        sample_size = (4, 32, 28, 28)
        n_channels = sample_size[1]
        x = torch.rand(sample_size)
        state_dict = dict(
            weight=torch.rand(n_channels),
            bias=torch.rand(n_channels),
            running_mean=torch.rand(n_channels),
            running_var=torch.rand(n_channels),
            num_batches_tracked=torch.tensor(100),
        )

        # First pass: default eps on both layers, to check the defaults agree.
        # Second pass: check the computation with an explicit eps > 0.
        for eps_kwargs in ({}, {"eps": 1e-5}):
            frozen = ops.misc.FrozenBatchNorm2d(n_channels, **eps_kwargs)
            frozen.load_state_dict(state_dict, strict=False)
            reference = torch.nn.BatchNorm2d(n_channels, **eps_kwargs).eval()
            reference.load_state_dict(state_dict)
            # Difference is expected to fall in an acceptable range
            torch.testing.assert_close(frozen(x), reference(x), rtol=1e-5, atol=1e-6)
1164

1165

Aditya Oke's avatar
Aditya Oke committed
1166
class TestBoxConversionToRoi:
    """Tests for the box-argument helpers in ``ops._utils``."""

    def _get_box_sequences():
        """Return the same two boxes as a 5-column tensor, a list of tensors and a tuple of tensors.

        Called without an instance: it runs while the class body is evaluated to
        feed ``parametrize``, and again from the conversion test.
        """
        # Define here the argument type of `boxes` supported by region pooling operations
        box_tensor = torch.tensor([[0, 0, 0, 100, 100], [1, 0, 0, 100, 100]], dtype=torch.float)
        box_list = [
            torch.tensor([[0, 0, 100, 100]], dtype=torch.float),
            torch.tensor([[0, 0, 100, 100]], dtype=torch.float),
        ]
        box_tuple = tuple(box_list)
        return box_tensor, box_list, box_tuple

    @pytest.mark.parametrize("box_sequence", _get_box_sequences())
    def test_check_roi_boxes_shape(self, box_sequence):
        """Every supported box container must pass the shape check."""
        # Ensure common sequences of tensors are supported
        ops._utils.check_roi_boxes_shape(box_sequence)

    @pytest.mark.parametrize("box_sequence", _get_box_sequences())
    def test_convert_boxes_to_roi_format(self, box_sequence):
        """Every supported box container must lead to the same roi tensor."""
        # Ensure common sequences of tensors yield the same result.
        # The 5-column tensor from the fixture is the reference.
        ref_tensor, _, _ = TestBoxConversionToRoi._get_box_sequences()
        if isinstance(box_sequence, torch.Tensor):
            # NOTE(review): assumes a tensor input is already in roi format
            # (index in column 0), so it is compared as-is — confirm.
            rois = box_sequence
        else:
            rois = ops._utils.convert_boxes_to_roi_format(box_sequence)
        assert_equal(ref_tensor, rois)
1190
1191


Aditya Oke's avatar
Aditya Oke committed
1192
class TestBoxConvert:
    """Tests for ``ops.box_convert`` between the "xyxy", "xywh" and "cxcywh" formats."""

    def test_bbox_same(self):
        """Converting a format to itself must return the boxes unchanged."""
        rows = [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]]
        box_tensor = torch.tensor(rows, dtype=torch.float)
        exp_xyxy = torch.tensor(rows, dtype=torch.float)

        assert exp_xyxy.size() == torch.Size([4, 4])
        for fmt in ("xyxy", "xywh", "cxcywh"):
            assert_equal(ops.box_convert(box_tensor, in_fmt=fmt, out_fmt=fmt), exp_xyxy)

    def test_bbox_xyxy_xywh(self):
        """xyxy -> xywh must give the expected boxes, and converting back must restore the input."""
        # Simple test convert boxes to xywh and back. Make sure they are same.
        # box_tensor is in x1 y1 x2 y2 format.
        xyxy_rows = [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]]
        xywh_rows = [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]]
        box_tensor = torch.tensor(xyxy_rows, dtype=torch.float)
        exp_xywh = torch.tensor(xywh_rows, dtype=torch.float)

        assert exp_xywh.size() == torch.Size([4, 4])
        box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
        assert_equal(box_xywh, exp_xywh)

        # Reverse conversion
        box_xyxy = ops.box_convert(box_xywh, in_fmt="xywh", out_fmt="xyxy")
        assert_equal(box_xyxy, box_tensor)

    def test_bbox_xyxy_cxcywh(self):
        """xyxy -> cxcywh must give the expected boxes, and converting back must restore the input."""
        # Simple test convert boxes to cxcywh and back. Make sure they are same.
        # box_tensor is in x1 y1 x2 y2 format.
        xyxy_rows = [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]]
        cxcywh_rows = [[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]]
        box_tensor = torch.tensor(xyxy_rows, dtype=torch.float)
        exp_cxcywh = torch.tensor(cxcywh_rows, dtype=torch.float)

        assert exp_cxcywh.size() == torch.Size([4, 4])
        box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
        assert_equal(box_cxcywh, exp_cxcywh)

        # Reverse conversion
        box_xyxy = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xyxy")
        assert_equal(box_xyxy, box_tensor)

    def test_bbox_xywh_cxcywh(self):
        """xywh -> cxcywh must give the expected boxes, and converting back must restore the input."""
        xywh_rows = [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]]
        cxcywh_rows = [[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]]
        box_tensor = torch.tensor(xywh_rows, dtype=torch.float)
        exp_cxcywh = torch.tensor(cxcywh_rows, dtype=torch.float)

        assert exp_cxcywh.size() == torch.Size([4, 4])
        box_cxcywh = ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="cxcywh")
        assert_equal(box_cxcywh, exp_cxcywh)

        # Reverse conversion
        box_xywh = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xywh")
        assert_equal(box_xywh, box_tensor)

    @pytest.mark.parametrize("inv_infmt", ["xwyh", "cxwyh"])
    @pytest.mark.parametrize("inv_outfmt", ["xwcx", "xhwcy"])
    def test_bbox_invalid(self, inv_infmt, inv_outfmt):
        """Unknown format names must raise ``ValueError``."""
        rows = [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]]
        box_tensor = torch.tensor(rows, dtype=torch.float)

        with pytest.raises(ValueError):
            ops.box_convert(box_tensor, inv_infmt, inv_outfmt)

    def test_bbox_convert_jit(self):
        """The scripted ``box_convert`` must agree with the eager one."""
        rows = [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]]
        box_tensor = torch.tensor(rows, dtype=torch.float)

        scripted_fn = torch.jit.script(ops.box_convert)

        for out_fmt in ("xywh", "cxcywh"):
            eager_out = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt=out_fmt)
            scripted_out = scripted_fn(box_tensor, "xyxy", out_fmt)
            torch.testing.assert_close(scripted_out, eager_out)
1280
1281


Aditya Oke's avatar
Aditya Oke committed
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
class TestBoxArea:
    """Tests for ``ops.box_area`` over integer, float and half-precision boxes."""

    def area_check(self, box, expected, atol=1e-4):
        """Assert that ``ops.box_area(box)`` equals ``expected`` within ``atol``, ignoring dtype."""
        torch.testing.assert_close(ops.box_area(box), expected, rtol=0.0, atol=atol, check_dtype=False)

    @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64])
    def test_int_boxes(self, dtype):
        """Integer boxes of every tested width must give the expected areas."""
        boxes = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype)
        areas = torch.tensor([10000, 0], dtype=torch.int32)
        self.area_check(boxes, areas)

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
    def test_float_boxes(self, dtype):
        """Areas of the module-level ``FLOAT_BOXES`` fixture in single and double precision."""
        boxes = torch.tensor(FLOAT_BOXES, dtype=dtype)
        areas = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=dtype)
        self.area_check(boxes, areas)

    def test_float16_box(self):
        """Half-precision boxes are checked with a looser absolute tolerance (0.01)."""
        rows = [[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]]
        boxes = torch.tensor(rows, dtype=torch.float16)

        areas = torch.tensor([3.2170, 3.7108, 18.5071], dtype=torch.float16)
        self.area_check(boxes, areas, atol=0.01)

    def test_box_area_jit(self):
        """The scripted ``box_area`` must agree with the eager one."""
        boxes = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)
        eager_area = ops.box_area(boxes)
        scripted_area = torch.jit.script(ops.box_area)(boxes)
        torch.testing.assert_close(scripted_area, eager_area)
1313

Aditya Oke's avatar
Aditya Oke committed
1314

1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
# Box fixtures shared by the box-area and IoU tests in this file; each row holds
# four coordinates.
# NOTE(review): rows look like [x1, y1, x2, y2] corner boxes — confirm against the ops under test.
INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300], [0, 0, 25, 25]]
# The first three rows of INT_BOXES.
INT_BOXES2 = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
# Three float boxes with nearly identical coordinates.
FLOAT_BOXES = [
    [285.3538, 185.5758, 1193.5110, 851.4551],
    [285.1472, 188.7374, 1192.4984, 851.0669],
    [279.2440, 197.9812, 1189.4746, 849.2019],
]


def gen_box(size, dtype=torch.float):
    """Return ``size`` random boxes as a ``(size, 4)`` tensor.

    The first two columns are a random point in [0, 1); the last two columns are
    that point plus another random offset in [0, 1), so they are never smaller.
    """
    first_corner = torch.rand((size, 2), dtype=dtype)
    extent = torch.rand((size, 2), dtype=dtype)
    second_corner = first_corner + extent
    return torch.cat((first_corner, second_corner), dim=-1)


Aditya Oke's avatar
Aditya Oke committed
1330
1331
class TestIouBase:
    """Static helpers for tests of pairwise box functions ``target_fn(boxes1, boxes2)``."""

    @staticmethod
    def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected):
        """Check ``target_fn`` against ``expected`` once for every dtype in ``dtypes``.

        Args:
            target_fn: function under test, called as ``target_fn(boxes1, boxes2)``.
            actual_box1: raw data for the first argument, converted with ``torch.tensor``.
            actual_box2: raw data for the second argument, converted with ``torch.tensor``.
            dtypes: dtypes in which the inputs are built, one check per dtype.
            atol: absolute tolerance of the comparison (the relative tolerance is 0).
            expected: raw data of the expected result.
        """
        expected_box = torch.tensor(expected)
        for dtype in dtypes:
            # Build each tensor from the raw inputs rather than rebinding the
            # parameters: otherwise every dtype after the first is converted from
            # the previous tensor (losing precision, e.g. float32 -> float64) and
            # torch.tensor() is asked to copy-construct from a tensor.
            box1 = torch.tensor(actual_box1, dtype=dtype)
            box2 = torch.tensor(actual_box2, dtype=dtype)
            out = target_fn(box1, box2)
            torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)

    @staticmethod
    def _run_jit_test(target_fn: Callable, actual_box: List):
        """Check that the scripted ``target_fn`` matches the eager one on ``actual_box`` paired with itself."""
        box_tensor = torch.tensor(actual_box, dtype=torch.float)
        expected = target_fn(box_tensor, box_tensor)
        scripted_fn = torch.jit.script(target_fn)
        scripted_out = scripted_fn(box_tensor, box_tensor)
        torch.testing.assert_close(scripted_out, expected)

    @staticmethod
    def _cartesian_product(boxes1, boxes2, target_fn: Callable):
        """Fill an ``(N, M)`` matrix by calling ``target_fn`` on every single-box pair."""
        N = boxes1.size(0)
        M = boxes2.size(0)
        result = torch.zeros((N, M))
        for i in range(N):
            for j in range(M):
                # unsqueeze(0) turns each box into a batch of one.
                result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0))
        return result

    @staticmethod
    def _run_cartesian_test(target_fn: Callable):
        """Check that one batched call equals the pair-by-pair matrix on random boxes."""
        boxes1 = gen_box(5)
        boxes2 = gen_box(7)
        a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn)
        b = target_fn(boxes1, boxes2)
        torch.testing.assert_close(a, b)
1365

1366

Aditya Oke's avatar
Aditya Oke committed
1367
class TestBoxIou(TestIouBase):
    """Tests for ``ops.box_iou``: reference values, TorchScript and pairwise consistency."""

    # Expected output for the (INT_BOXES, INT_BOXES2) case.
    int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0625, 0.25, 0.0]]
    # Expected output for the (FLOAT_BOXES, FLOAT_BOXES) cases.
    float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        """Compare the op's output with the hard-coded expected values for each dtype."""
        self._run_test(ops.box_iou, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        """The scripted op must match eager execution."""
        self._run_jit_test(ops.box_iou, INT_BOXES)

    def test_iou_cartesian(self):
        """The batched result must match evaluating every box pair separately."""
        self._run_cartesian_test(ops.box_iou)

class TestGeneralizedBoxIou(TestIouBase):
    """Tests for ``ops.generalized_box_iou``: reference values, TorchScript and pairwise consistency."""

    # Expected output for the (INT_BOXES, INT_BOXES2) case.
    int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0], [0.0625, 0.25, -0.8819]]
    # Expected output for the (FLOAT_BOXES, FLOAT_BOXES) cases.
    float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        """Compare the op's output with the hard-coded expected values for each dtype."""
        self._run_test(ops.generalized_box_iou, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        """The scripted op must match eager execution."""
        self._run_jit_test(ops.generalized_box_iou, INT_BOXES)

    def test_iou_cartesian(self):
        """The batched result must match evaluating every box pair separately."""
        self._run_cartesian_test(ops.generalized_box_iou)

class TestDistanceBoxIoU(TestIouBase):
    """Tests for ``ops.distance_box_iou``: reference values, TorchScript and pairwise consistency."""

    # Expected output for the (INT_BOXES, INT_BOXES2) case.
    int_expected = [
        [1.0000, 0.1875, -0.4444],
        [0.1875, 1.0000, -0.5625],
        [-0.4444, -0.5625, 1.0000],
        [-0.0781, 0.1875, -0.6267],
    ]
    # Expected output for the (FLOAT_BOXES, FLOAT_BOXES) cases.
    float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        """Compare the op's output with the hard-coded expected values for each dtype."""
        self._run_test(ops.distance_box_iou, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        """The scripted op must match eager execution."""
        self._run_jit_test(ops.distance_box_iou, INT_BOXES)

    def test_iou_cartesian(self):
        """The batched result must match evaluating every box pair separately."""
        self._run_cartesian_test(ops.distance_box_iou)

class TestCompleteBoxIou(TestIouBase):
    """Tests for ``ops.complete_box_iou``: reference values, TorchScript and pairwise consistency."""

    # Expected output for the (INT_BOXES, INT_BOXES2) case.
    int_expected = [
        [1.0000, 0.1875, -0.4444],
        [0.1875, 1.0000, -0.5625],
        [-0.4444, -0.5625, 1.0000],
        [-0.0781, 0.1875, -0.6267],
    ]
    # Expected output for the (FLOAT_BOXES, FLOAT_BOXES) cases.
    float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        """Compare the op's output with the hard-coded expected values for each dtype."""
        self._run_test(ops.complete_box_iou, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        """The scripted op must match eager execution."""
        self._run_jit_test(ops.complete_box_iou, INT_BOXES)

    def test_iou_cartesian(self):
        """The batched result must match evaluating every box pair separately."""
        self._run_cartesian_test(ops.complete_box_iou)

def get_boxes(dtype, device):
    """Build the fixed boxes used by the IoU-loss tests.

    Args:
        dtype: dtype of every returned tensor.
        device: device of every returned tensor.

    Returns:
        ``(box1, box2, box3, box4, box1s, box2s)`` where ``box1``..``box4`` are
        4-element tensors, ``box1s`` stacks ``box2`` twice and ``box2s`` stacks
        ``box3`` and ``box4`` (both along dim 0).
    """

    def make_box(coords):
        return torch.tensor(coords, dtype=dtype, device=device)

    box1 = make_box([-1, -1, 1, 1])
    box2 = make_box([0, 0, 1, 1])
    box3 = make_box([0, 1, 1, 2])
    box4 = make_box([1, 1, 2, 2])

    box1s = torch.stack([box2, box2], dim=0)
    box2s = torch.stack([box3, box4], dim=0)

    return box1, box2, box3, box4, box1s, box2s
def assert_iou_loss(iou_fn, box1, box2, expected_loss, device, reduction="none"):
    """Assert that ``iou_fn(box1, box2, reduction=reduction)`` is close to ``expected_loss``.

    Args:
        iou_fn: Loss callable accepting two box tensors and a ``reduction`` keyword.
        box1: First argument forwarded to ``iou_fn``.
        box2: Second argument forwarded to ``iou_fn``.
        expected_loss: Expected value; converted with ``torch.tensor`` on ``device``.
        device: Device on which the expected tensor is created.
        reduction: Reduction mode forwarded to ``iou_fn``.

    Raises:
        AssertionError: If the computed loss is not close to the expected one
            under ``torch.testing.assert_close`` defaults.
    """
    computed_loss = iou_fn(box1, box2, reduction=reduction)
    expected_loss = torch.tensor(expected_loss, device=device)
    torch.testing.assert_close(computed_loss, expected_loss)
def assert_empty_loss(iou_fn, dtype, device):
    """Check how ``iou_fn`` behaves on two empty ``(0, 4)`` box tensors.

    With ``reduction="mean"`` the loss must be 0 and backward must populate
    gradients on both inputs; with ``reduction="none"`` the loss must have no
    elements.

    Args:
        iou_fn: Loss callable accepting two box tensors and a ``reduction`` keyword.
        dtype: dtype of the generated boxes.
        device: Device of the generated boxes and of the expected zero tensor.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    box1 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()
    box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()

    loss = iou_fn(box1, box2, reduction="mean")
    loss.backward()
    torch.testing.assert_close(loss, torch.tensor(0.0, device=device))
    assert box1.grad is not None, "box1.grad should not be None after backward is called"
    assert box2.grad is not None, "box2.grad should not be None after backward is called"

    loss = iou_fn(box1, box2, reduction="none")
    assert loss.numel() == 0, f"{iou_fn} for two empty box should be empty"
class TestGeneralizedBoxIouLoss:
    """Tests for ``ops.generalized_box_iou_loss``."""

    # We refer to original test: https://github.com/facebookresearch/fvcore/blob/main/tests/test_giou_loss.py
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_giou_loss(self, dtype, device):
        """Check the loss on hand-computed box pairs and for each reduction mode."""
        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

        # Identical boxes should have loss of 0
        assert_iou_loss(ops.generalized_box_iou_loss, box1, box1, 0.0, device=device)

        # quarter size box inside other box = IoU of 0.25
        assert_iou_loss(ops.generalized_box_iou_loss, box1, box2, 0.75, device=device)

        # Two side by side boxes, area=union
        # IoU=0 and GIoU=0 (loss 1.0)
        assert_iou_loss(ops.generalized_box_iou_loss, box2, box3, 1.0, device=device)

        # Two diagonally adjacent boxes, area=2*union
        # IoU=0 and GIoU=-0.5 (loss 1.5)
        assert_iou_loss(ops.generalized_box_iou_loss, box2, box4, 1.5, device=device)

        # Test batched loss and reductions
        assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 2.5, device=device, reduction="sum")
        assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 1.25, device=device, reduction="mean")

        # Test reduction value
        # reduction value other than ["none", "mean", "sum"] should raise a ValueError
        with pytest.raises(ValueError, match="Invalid"):
            ops.generalized_box_iou_loss(box1s, box2s, reduction="xyz")

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_empty_inputs(self, dtype, device):
        """Empty inputs must give a zero mean loss, gradients, and an empty unreduced loss."""
        assert_empty_loss(ops.generalized_box_iou_loss, dtype, device)
class TestCompleteBoxIouLoss:
    """Tests for ``ops.complete_box_iou_loss``."""

    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_ciou_loss(self, dtype, device):
        """Check the loss on the fixture boxes and for each reduction mode."""
        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

        # Single box pairs, unreduced.
        assert_iou_loss(ops.complete_box_iou_loss, box1, box1, 0.0, device=device)
        assert_iou_loss(ops.complete_box_iou_loss, box1, box2, 0.8125, device=device)
        assert_iou_loss(ops.complete_box_iou_loss, box1, box3, 1.1923, device=device)
        assert_iou_loss(ops.complete_box_iou_loss, box1, box4, 1.2500, device=device)

        # Batched boxes with a reduction.
        assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 1.2250, device=device, reduction="mean")
        assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 2.4500, device=device, reduction="sum")

        # An unknown reduction mode must be rejected.
        with pytest.raises(ValueError, match="Invalid"):
            ops.complete_box_iou_loss(box1s, box2s, reduction="xyz")

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_empty_inputs(self, dtype, device):
        """Empty inputs must give a zero mean loss, gradients, and an empty unreduced loss."""
        assert_empty_loss(ops.complete_box_iou_loss, dtype, device)
class TestDistanceBoxIouLoss:
    """Tests for ``ops.distance_box_iou_loss``."""

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_distance_iou_loss(self, dtype, device):
        """Check the loss on the fixture boxes and for each reduction mode."""
        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

        # Single box pairs, unreduced.
        assert_iou_loss(ops.distance_box_iou_loss, box1, box1, 0.0, device=device)
        assert_iou_loss(ops.distance_box_iou_loss, box1, box2, 0.8125, device=device)
        assert_iou_loss(ops.distance_box_iou_loss, box1, box3, 1.1923, device=device)
        assert_iou_loss(ops.distance_box_iou_loss, box1, box4, 1.2500, device=device)

        # Batched boxes with a reduction.
        assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 1.2250, device=device, reduction="mean")
        assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 2.4500, device=device, reduction="sum")

        # An unknown reduction mode must be rejected.
        with pytest.raises(ValueError, match="Invalid"):
            ops.distance_box_iou_loss(box1s, box2s, reduction="xyz")

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_empty_distance_iou_inputs(self, dtype, device):
        """Empty inputs must give a zero mean loss, gradients, and an empty unreduced loss."""
        assert_empty_loss(ops.distance_box_iou_loss, dtype, device)
class TestFocalLoss:
    """Property-style tests for ``ops.sigmoid_focal_loss``."""

    def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs):
        """Build logits and targets covering several probability/target regimes.

        Args:
            shape: Shape of each generated chunk.
            **kwargs: Forwarded to the tensor factories (callers in this class
                pass ``dtype`` and ``device``).

        Returns:
            ``(inputs, targets)``, each the concatenation of 9 chunks, i.e. of
            shape ``(shape[0] * 9, shape[1])``. ``inputs`` are logits of the
            generated probabilities.
        """

        def logit(p):
            return torch.log(p / (1 - p))

        # Half-open sampling ranges handed to ``make_tensor``.
        uniform_ranges = {
            "small": (0.0, 0.2),
            "big": (0.8, 1.0),
            "random": (0.0, 1.0),
        }
        # Constant tensors are built with ``torch.full``: calling ``make_tensor``
        # with ``low == high`` is deprecated for floating dtypes.
        constant_fills = {"zeros": 0.0, "ones": 1.0}

        def generate_tensor_with_range_type(shape, range_type, **kwargs):
            if range_type == "random_binary":
                return torch.randint(0, 2, shape, **kwargs)
            if range_type in constant_fills:
                return torch.full(shape, constant_fills[range_type], **kwargs)
            low, high = uniform_ranges[range_type]
            return torch.testing.make_tensor(shape, low=low, high=high, **kwargs)

        # This function will return inputs and targets with shape: (shape[0]*9, shape[1])
        inputs = []
        targets = []
        for input_range_type, target_range_type in [
            ("small", "zeros"),
            ("small", "ones"),
            ("small", "random_binary"),
            ("big", "zeros"),
            ("big", "ones"),
            ("big", "random_binary"),
            ("random", "zeros"),
            ("random", "ones"),
            ("random", "random_binary"),
        ]:
            inputs.append(logit(generate_tensor_with_range_type(shape, input_range_type, **kwargs)))
            targets.append(generate_tensor_with_range_type(shape, target_range_type, **kwargs))

        return torch.cat(inputs), torch.cat(targets)

    @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
    @pytest.mark.parametrize("gamma", [0, 2])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("seed", [0, 1])
    def test_correct_ratio(self, alpha, gamma, device, dtype, seed):
        """The focal/cross-entropy ratio must equal the manually computed modulating factor."""
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        # For testing the ratio with manual calculation, we require the reduction to be "none"
        reduction = "none"
        torch.random.manual_seed(seed)
        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
        focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
        ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction)

        assert torch.all(
            focal_loss <= ce_loss
        ), "focal loss must be less or equal to cross entropy loss with same input"

        loss_ratio = (focal_loss / ce_loss).squeeze()
        prob = torch.sigmoid(inputs)
        p_t = prob * targets + (1 - prob) * (1 - targets)
        correct_ratio = (1.0 - p_t) ** gamma
        # The alpha weighting is only folded in for non-negative alpha.
        if alpha >= 0:
            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
            correct_ratio = correct_ratio * alpha_t

        tol = 1e-3 if dtype is torch.half else 1e-5
        torch.testing.assert_close(correct_ratio, loss_ratio, atol=tol, rtol=tol)

    @pytest.mark.parametrize("reduction", ["mean", "sum"])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("seed", [2, 3])
    def test_equal_ce_loss(self, reduction, device, dtype, seed):
        """With alpha=-1 and gamma=0 the loss and its gradient must match BCE-with-logits."""
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        # focal loss should be equal ce_loss if alpha=-1 and gamma=0
        alpha = -1
        gamma = 0
        torch.random.manual_seed(seed)
        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
        # Separate leaf tensors so each loss accumulates its own gradient.
        inputs_fl = inputs.clone().requires_grad_()
        targets_fl = targets.clone()
        inputs_ce = inputs.clone().requires_grad_()
        targets_ce = targets.clone()
        focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction)
        ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction)

        torch.testing.assert_close(focal_loss, ce_loss)

        focal_loss.backward()
        ce_loss.backward()
        torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad)

    @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
    @pytest.mark.parametrize("gamma", [0, 2])
    @pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("seed", [4, 5])
    def test_jit(self, alpha, gamma, reduction, device, dtype, seed):
        """The scripted loss must match eager execution."""
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        script_fn = torch.jit.script(ops.sigmoid_focal_loss)
        torch.random.manual_seed(seed)
        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
        focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
        scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)

        tol = 1e-3 if dtype is torch.half else 1e-5
        torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol)

    # Raise ValueError for anonymous reduction mode
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_reduction_mode(self, device, dtype, reduction="xyz"):
        """An unknown reduction mode must raise ``ValueError``."""
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        torch.random.manual_seed(0)
        inputs, targets = self._generate_diverse_input_target_pair(device=device, dtype=dtype)
        with pytest.raises(ValueError, match="Invalid"):
            ops.sigmoid_focal_loss(inputs, targets, 0.25, 2, reduction)

class TestMasksToBoxes:
    """Tests for ``ops.masks_to_boxes``."""

    def test_masks_box(self):
        """Boxes computed from the frames of ``assets/masks.tiff`` must match the reference."""

        def masks_box_check(masks, expected, atol=1e-4):
            out = ops.masks_to_boxes(masks)
            assert out.dtype == torch.float
            torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=True, atol=atol)

        def _get_image():
            assets_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
            mask_path = os.path.join(assets_directory, "masks.tiff")
            image = Image.open(mask_path)
            return image

        def _create_masks(image, masks):
            # Copy every frame of the multi-frame image into its slot of ``masks``.
            for index in range(image.n_frames):
                image.seek(index)
                frame = np.array(image)
                masks[index] = torch.tensor(frame)

            return masks

        expected = torch.tensor(
            [
                [127, 2, 165, 40],
                [2, 50, 44, 92],
                [56, 63, 98, 100],
                [139, 68, 175, 104],
                [160, 112, 198, 145],
                [49, 138, 99, 182],
                [108, 148, 152, 213],
            ],
            dtype=torch.float,
        )

        # Keep the file open only while frames are being read, then close it.
        with _get_image() as image:
            # The output must be float boxes whatever the floating dtype of the masks.
            for dtype in [torch.float16, torch.float32, torch.float64]:
                masks = torch.zeros((image.n_frames, image.height, image.width), dtype=dtype)
                masks = _create_masks(image, masks)
                masks_box_check(masks, expected)


class TestStochasticDepth:
    """Tests for ``ops.StochasticDepth``."""

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("p", [0.2, 0.5, 0.8])
    @pytest.mark.parametrize("mode", ["batch", "row"])
    def test_stochastic_depth_random(self, seed, mode, p):
        """The observed drop frequency must be statistically consistent with ``p``."""
        torch.manual_seed(seed)
        stats = pytest.importorskip("scipy.stats")
        batch_size = 5
        x = torch.ones(size=(batch_size, 3, 4, 4))
        layer = ops.StochasticDepth(p=p, mode=mode)
        # Smoke check: building the repr must not raise.
        repr(layer)

        trials = 250
        num_samples = 0
        counts = 0
        for _ in range(trials):
            out = layer(x)
            # Number of batch rows that still contain a non-zero value.
            non_zero_count = out.sum(dim=(1, 2, 3)).nonzero().size(0)
            if mode == "batch":
                # One sample per trial: counted as a drop when every row is zero.
                if non_zero_count == 0:
                    counts += 1
                num_samples += 1
            elif mode == "row":
                # One sample per row.
                counts += batch_size - non_zero_count
                num_samples += batch_size

        p_value = stats.binomtest(counts, num_samples, p=p).pvalue
        assert p_value > 0.01

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("p", (0, 1))
    @pytest.mark.parametrize("mode", ["batch", "row"])
    def test_stochastic_depth(self, seed, mode, p):
        """``p=0`` must return the input unchanged and ``p=1`` must return all zeros."""
        torch.manual_seed(seed)
        batch_size = 5
        x = torch.ones(size=(batch_size, 3, 4, 4))
        layer = ops.StochasticDepth(p=p, mode=mode)

        out = layer(x)
        if p == 0:
            assert out.equal(x)
        elif p == 1:
            assert out.equal(torch.zeros_like(x))

    def make_obj(self, p, mode, wrap=False):
        """Build a ``StochasticDepth`` layer, optionally wrapped in ``StochasticDepthWrapper``."""
        obj = ops.StochasticDepth(p, mode)
        return StochasticDepthWrapper(obj) if wrap else obj

    @pytest.mark.parametrize("p", (0, 1))
    @pytest.mark.parametrize("mode", ["batch", "row"])
    def test_is_leaf_node(self, p, mode):
        """Check the node names ``get_graph_node_names`` reports for the wrapped layer."""
        op_obj = self.make_obj(p, mode, wrap=True)
        graph_node_names = get_graph_node_names(op_obj)

        assert len(graph_node_names) == 2
        assert len(graph_node_names[0]) == len(graph_node_names[1])
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs

class TestUtils:
    """Tests for helpers in ``ops._utils``."""

    @pytest.mark.parametrize("norm_layer", [None, nn.BatchNorm2d, nn.LayerNorm])
    def test_split_normalization_params(self, norm_layer):
        """Check the sizes of the two parameter groups returned for MobileNetV3-Large."""
        model = models.mobilenet_v3_large(norm_layer=norm_layer)
        # ``None`` lets the helper use its own default set of normalization classes.
        params = ops._utils.split_normalization_params(model, None if norm_layer is None else [norm_layer])

        assert len(params[0]) == 92
        assert len(params[1]) == 82


class TestDropBlock:
    """Tests for ``ops.DropBlock2d`` and ``ops.DropBlock3d``."""

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("dim", [2, 3])
    @pytest.mark.parametrize("p", [0, 0.5])
    @pytest.mark.parametrize("block_size", [5, 11])
    @pytest.mark.parametrize("inplace", [True, False])
    def test_drop_block(self, seed, dim, p, block_size, inplace):
        """``p=0`` must be the identity; a full-size block must drop whole feature maps or nothing."""
        torch.manual_seed(seed)
        batch_size = 5
        channels = 3
        height = 11
        width = height
        depth = height
        if dim == 2:
            x = torch.ones(size=(batch_size, channels, height, width))
            layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace)
            feature_size = height * width
        elif dim == 3:
            x = torch.ones(size=(batch_size, channels, depth, height, width))
            layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace)
            feature_size = depth * height * width
        # Smoke check: building the repr must not raise.
        repr(layer)

        out = layer(x)
        if p == 0:
            assert out.equal(x)
        if block_size == height:
            # A block as large as the feature map leaves it either untouched or fully zeroed.
            for b, c in product(range(batch_size), range(channels)):
                assert out[b, c].count_nonzero() in (0, feature_size)

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("dim", [2, 3])
    @pytest.mark.parametrize("p", [0.1, 0.2])
    @pytest.mark.parametrize("block_size", [3])
    @pytest.mark.parametrize("inplace", [False])
    def test_drop_block_random(self, seed, dim, p, block_size, inplace):
        """The observed fraction of zeroed cells must be within 15% (relative) of ``p``."""
        torch.manual_seed(seed)
        batch_size = 5
        channels = 3
        height = 11
        width = height
        depth = height
        if dim == 2:
            x = torch.ones(size=(batch_size, channels, height, width))
            layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace)
        elif dim == 3:
            x = torch.ones(size=(batch_size, channels, depth, height, width))
            layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace)

        trials = 250
        num_samples = 0
        counts = 0
        # Total number of cells in the input.
        cell_numel = x.numel()
        for _ in range(trials):
            with torch.no_grad():
                out = layer(x)
            non_zero_count = out.nonzero().size(0)
            counts += cell_numel - non_zero_count
            num_samples += cell_numel

        assert abs(p - counts / num_samples) / p < 0.15

    def make_obj(self, dim, p, block_size, inplace, wrap=False):
        """Build a DropBlock layer for ``dim`` (2 or 3), optionally wrapped in ``DropBlockWrapper``.

        Raises:
            ValueError: If ``dim`` is neither 2 nor 3.
        """
        if dim == 2:
            obj = ops.DropBlock2d(p, block_size, inplace)
        elif dim == 3:
            obj = ops.DropBlock3d(p, block_size, inplace)
        else:
            raise ValueError(f"dim must be 2 or 3, got {dim}")
        return DropBlockWrapper(obj) if wrap else obj

    @pytest.mark.parametrize("dim", (2, 3))
    @pytest.mark.parametrize("p", [0, 1])
    @pytest.mark.parametrize("block_size", [5, 7])
    @pytest.mark.parametrize("inplace", [True, False])
    def test_is_leaf_node(self, dim, p, block_size, inplace):
        """Check the node names ``get_graph_node_names`` reports for the wrapped layer."""
        op_obj = self.make_obj(dim, p, block_size, inplace, wrap=True)
        graph_node_names = get_graph_node_names(op_obj)

        assert len(graph_node_names) == 2
        assert len(graph_node_names[0]) == len(graph_node_names[1])
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs


if __name__ == "__main__":
    # Propagate pytest's exit code so running this file as a script reports failures.
    raise SystemExit(pytest.main([__file__]))