test_ops.py 76.5 KB
Newer Older
1
import math
2
import os
3
from abc import ABC, abstractmethod
4
from functools import lru_cache
5
from itertools import product
6
from typing import Callable, List, Tuple
7

8
import numpy as np
9
import pytest
10
import torch
11
import torch.fx
12
import torch.nn.functional as F
13
import torch.testing._internal.optests as optests
14
from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps
15
from PIL import Image
16
from torch import nn, Tensor
17
from torch.autograd import gradcheck
18
from torch.nn.modules.utils import _pair
19
from torchvision import models, ops
20
21
22
from torchvision.models.feature_extraction import get_graph_node_names


23
24
25
26
27
28
29
30
# Names of the opcheck utilities handed to optests.generate_opcheck_tests
# through its test_utils argument further down in this file.
OPTESTS = [
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
]


31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# Context manager for setting deterministic flag and automatically
# resetting it to its original value
class DeterministicGuard:
    """Temporarily set torch's deterministic-algorithms flags.

    On entry the current ``deterministic`` and ``warn_only`` settings are
    saved and replaced by the requested ones; on exit the saved settings are
    put back.

    Args:
        deterministic: value passed to ``torch.use_deterministic_algorithms``.
        warn_only: keyword-only ``warn_only`` value passed alongside it.
    """

    def __init__(self, deterministic, *, warn_only=False):
        self.deterministic = deterministic
        self.warn_only = warn_only

    def __enter__(self):
        self.deterministic_restore = torch.are_deterministic_algorithms_enabled()
        self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled()
        torch.use_deterministic_algorithms(self.deterministic, warn_only=self.warn_only)
        # Hand the guard back so ``with DeterministicGuard(...) as guard`` binds it.
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        # Returns None, so an exception raised inside the block still propagates.
        torch.use_deterministic_algorithms(self.deterministic_restore, warn_only=self.warn_only_restore)


47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
class RoIOpTesterModuleWrapper(nn.Module):
    """Two-input ``nn.Module`` shell around an op object.

    ``n_inputs`` records how many positional inputs ``forward`` takes.
    ``forward`` calls the wrapped layer and drops its result.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 2
        self.layer = obj

    def forward(self, a, b):
        # The wrapped layer's output is not returned; forward yields None.
        self.layer(a, b)
        return None


class MultiScaleRoIAlignModuleWrapper(nn.Module):
    """Three-input ``nn.Module`` shell around an op object.

    ``n_inputs`` records how many positional inputs ``forward`` takes.
    ``forward`` calls the wrapped layer and drops its result.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 3
        self.layer = obj

    def forward(self, a, b, c):
        # The wrapped layer's output is not returned; forward yields None.
        self.layer(a, b, c)
        return None


class DeformConvModuleWrapper(nn.Module):
    """Three-input ``nn.Module`` shell around an op object.

    ``n_inputs`` records how many positional inputs ``forward`` takes.
    ``forward`` calls the wrapped layer and drops its result.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 3
        self.layer = obj

    def forward(self, a, b, c):
        # The wrapped layer's output is not returned; forward yields None.
        self.layer(a, b, c)
        return None


class StochasticDepthWrapper(nn.Module):
    """Single-input ``nn.Module`` shell around an op object.

    ``n_inputs`` records how many positional inputs ``forward`` takes.
    ``forward`` calls the wrapped layer and drops its result.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 1
        self.layer = obj

    def forward(self, a):
        # The wrapped layer's output is not returned; forward yields None.
        self.layer(a)
        return None
85
86


87
88
89
90
91
92
93
94
95
96
class DropBlockWrapper(nn.Module):
    """Single-input ``nn.Module`` shell around an op object.

    ``n_inputs`` records how many positional inputs ``forward`` takes.
    ``forward`` calls the wrapped layer and drops its result.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 1
        self.layer = obj

    def forward(self, a):
        # The wrapped layer's output is not returned; forward yields None.
        self.layer(a)
        return None


97
98
99
100
101
102
103
104
105
class PoolWrapper(nn.Module):
    """Module adapter whose ``forward`` returns ``pool(imgs, boxes)``."""

    def __init__(self, pool: nn.Module):
        super().__init__()
        self.pool = pool

    def forward(self, imgs: Tensor, boxes: List[Tensor]) -> Tensor:
        pooled = self.pool(imgs, boxes)
        return pooled


106
107
class RoIOpTester(ABC):
    dtype = torch.float64
108
109
    mps_dtype = torch.float32
    mps_backward_atol = 2e-2
110

111
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
112
    @pytest.mark.parametrize("contiguous", (True, False))
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
    @pytest.mark.parametrize(
        "x_dtype",
        (
            torch.float16,
            torch.float32,
            torch.float64,
        ),
        ids=str,
    )
    def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, deterministic=False, **kwargs):
        """Check ``self.fn`` against ``self.expected_fn`` on a fixed set of RoIs.

        Args:
            device: device the input and the RoIs are created on.
            contiguous: when False, the input is permuted before being used.
            x_dtype: dtype of the input feature map.
            rois_dtype: dtype of the RoIs; falls back to ``x_dtype`` when None.
            deterministic: flag handed to ``DeterministicGuard`` around the op call.
            **kwargs: forwarded to both ``self.fn`` and ``self.expected_fn``.
        """
        if device == "mps" and x_dtype is torch.float64:
            pytest.skip("MPS does not support float64")

        rois_dtype = x_dtype if rois_dtype is None else rois_dtype

        # Half precision gets a looser tolerance than the 1e-5 default.
        tol = 1e-5
        if x_dtype is torch.half:
            if device == "mps":
                tol = 5e-3
            else:
                tol = 4e-3

        pool_size = 5
        # n_channels % (pool_size ** 2) == 0 required for PS operations.
        n_channels = 2 * (pool_size**2)
        x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device)
        if not contiguous:
            x = x.permute(0, 1, 3, 2)
        rois = torch.tensor(
            [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]],  # format is (xyxy)
            dtype=rois_dtype,
            device=device,
        )

        pool_h, pool_w = pool_size, pool_size
        with DeterministicGuard(deterministic):
            y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs)
        # the following should be true whether we're running an autocast test or not.
        assert y.dtype == x.dtype
        gt_y = self.expected_fn(
            x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=x_dtype, **kwargs
        )

        torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)
157

158
    @pytest.mark.parametrize("device", cpu_and_cuda())
159
160
161
162
163
164
165
166
    def test_is_leaf_node(self, device):
        """Check the graph node names reported for the wrapped op.

        ``get_graph_node_names`` must return two equally long sequences, each
        holding one entry per wrapper input plus one more.
        """
        wrapped = self.make_obj(wrap=True).to(device=device)
        node_names = get_graph_node_names(wrapped)
        expected_count = 1 + wrapped.n_inputs

        assert len(node_names) == 2
        assert len(node_names[0]) == len(node_names[1])
        assert len(node_names[0]) == expected_count

167
    @pytest.mark.parametrize("device", cpu_and_cuda())
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
    def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.float):
        """Check that the symbolically traced op matches the eager op.

        Both outputs must keep the input dtype and agree within 1e-5.
        """
        eager_op = self.make_obj().to(device=device)
        traced_op = torch.fx.symbolic_trace(eager_op)

        pool_size = 5
        n_channels = 2 * (pool_size**2)
        x = torch.rand(2, n_channels, 5, 5, dtype=x_dtype, device=device)
        roi_rows = [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]]  # format is (xyxy)
        rois = torch.tensor(roi_rows, dtype=rois_dtype, device=device)

        output_gt = eager_op(x, rois)
        assert output_gt.dtype == x.dtype
        output_fx = traced_op(x, rois)
        assert output_fx.dtype == x.dtype

        tol = 1e-5
        torch.testing.assert_close(output_gt, output_fx, rtol=tol, atol=tol)

186
    @pytest.mark.parametrize("seed", range(10))
187
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
188
    @pytest.mark.parametrize("contiguous", (True, False))
189
    def test_backward(self, seed, device, contiguous, deterministic=False):
        """Run ``gradcheck`` on the eager op and on its scripted variant.

        Args:
            seed: seed for torch's random generator.
            device: device the input and the RoIs are created on.
            contiguous: when False, the input is permuted before being used.
            deterministic: flag handed to ``DeterministicGuard`` around the
                eager gradcheck only.
        """
        # MPS uses its own dtype and a looser per-class tolerance.
        atol = self.mps_backward_atol if device == "mps" else 1e-05
        dtype = self.mps_dtype if device == "mps" else self.dtype

        torch.random.manual_seed(seed)
        pool_size = 2
        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=dtype, device=device, requires_grad=True)
        if not contiguous:
            x = x.permute(0, 1, 3, 2)
        rois = torch.tensor(
            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=dtype, device=device  # format is (xyxy)
        )

        def func(z):
            return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)

        script_func = self.get_script_fn(rois, pool_size)

        with DeterministicGuard(deterministic):
            gradcheck(func, (x,), atol=atol)

        gradcheck(script_func, (x,), atol=atol)

    @needs_mps
    def test_mps_error_inputs(self):
        """Check that gradcheck on float16 MPS inputs raises the expected RuntimeError."""
        pool_size = 2
        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=torch.float16, device="mps", requires_grad=True)
        rois = torch.tensor(
            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=torch.float16, device="mps"  # format is (xyxy)
        )

        def func(z):
            return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)

        with pytest.raises(
            RuntimeError, match="MPS does not support (?:ps_)?roi_(?:align|pool)? backward with float16 inputs."
        ):
            gradcheck(func, (x,))
227

228
    @needs_cuda
229
230
    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
231
232
233
    def test_autocast(self, x_dtype, rois_dtype):
        """Run ``test_forward`` on CUDA with a non-contiguous input under CUDA autocast."""
        cuda_device = torch.device("cuda")
        with torch.cuda.amp.autocast():
            self.test_forward(cuda_device, contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
234
235
236

    def _helper_boxes_shape(self, func):
        """Check that ``func`` rejects RoI boxes with the wrong number of columns.

        Both box formats are exercised against ``func`` itself.

        Args:
            func: RoI op called as ``func(input, boxes, output_size=...)``.
        """
        # Inputs are built outside ``pytest.raises`` so only the op call can
        # satisfy the expected AssertionError.
        a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)

        # test boxes as Tensor[N, 5]: a 4-column tensor must be rejected
        boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype)
        with pytest.raises(AssertionError):
            func(a, boxes, output_size=(2, 2))

        # test boxes as List[Tensor[N, 4]]: a 3-column tensor must be rejected
        boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype)
        with pytest.raises(AssertionError):
            func(a, [boxes], output_size=(2, 2))

248
249
250
251
252
253
254
255
    def _helper_jit_boxes_list(self, model):
        x = torch.rand(2, 1, 10, 10)
        roi = torch.tensor([[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], dtype=torch.float).t()
        rois = [roi, roi]
        scriped = torch.jit.script(model)
        y = scriped(x, rois)
        assert y.shape == (10, 1, 3, 3)

256
    @abstractmethod
257
258
    def fn(*args, **kwargs):
        """Apply the op under test; concrete testers override this."""
259

260
261
262
263
    @abstractmethod
    def make_obj(*args, **kwargs):
        """Build the op module under test; concrete testers override this."""

264
    @abstractmethod
265
266
    def get_script_fn(*args, **kwargs):
        """Build the scripted variant of the op; concrete testers override this."""
267

268
    @abstractmethod
269
270
    def expected_fn(*args, **kwargs):
        """Compute the reference output for the op; concrete testers override this."""
271

272

273
class TestRoiPool(RoIOpTester):
274
275
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        """Apply a freshly built ``ops.RoIPool`` to ``x`` and ``rois``.

        ``sampling_ratio`` and any extra keyword arguments are accepted but not used.
        """
        output_size = (pool_h, pool_w)
        pooler = ops.RoIPool(output_size, spatial_scale)
        return pooler(x, rois)
276

277
278
279
280
    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False):
        """Build an ``ops.RoIPool``; with ``wrap=True`` return it inside ``RoIOpTesterModuleWrapper``."""
        pooler = ops.RoIPool((pool_h, pool_w), spatial_scale)
        if wrap:
            return RoIOpTesterModuleWrapper(pooler)
        return pooler

281
    def get_script_fn(self, rois, pool_size):
        """Return a one-argument callable that runs scripted ``ops.roi_pool``.

        The returned callable takes the input ``x`` and calls the scripted op
        with the given ``rois`` and ``pool_size``.
        """
        scripted = torch.jit.script(ops.roi_pool)
        return lambda x: scripted(x, rois, pool_size)
284

285
286
287
    def expected_fn(
        self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64
    ):
        """Reference RoI max pooling written with plain Python loops.

        Args:
            x: input indexed as ``x[batch, channel, row, col]``.
            rois: rows of ``(batch_idx, x1, y1, x2, y2)``.
            pool_h, pool_w: output height and width per RoI.
            spatial_scale: factor applied to the RoI coordinates before rounding.
            sampling_ratio: accepted for signature parity; not used here.
            device: device of the result; CPU when None.
            dtype: dtype of the result.

        Returns:
            Tensor of shape ``(len(rois), channels, pool_h, pool_w)`` holding the
            per-bin maximum; bins that select no element stay 0.
        """
        if device is None:
            device = torch.device("cpu")

        n_channels = x.size(1)
        y = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device)

        def get_slice(k, block):
            # Bin k covers [floor(k * block), ceil((k + 1) * block)).
            return slice(int(np.floor(k * block)), int(np.ceil((k + 1) * block)))

        for roi_idx, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (int(round(coord.item() * spatial_scale)) for coord in roi[1:])
            # End coordinates are inclusive, hence the + 1.
            roi_x = x[batch_idx, :, i_begin : i_end + 1, j_begin : j_end + 1]

            roi_h, roi_w = roi_x.shape[-2:]
            bin_h = roi_h / pool_h
            bin_w = roi_w / pool_w

            for i in range(0, pool_h):
                for j in range(0, pool_w):
                    bin_x = roi_x[:, get_slice(i, bin_h), get_slice(j, bin_w)]
                    if bin_x.numel() > 0:
                        y[roi_idx, :, i, j] = bin_x.reshape(n_channels, -1).max(dim=1)[0]
        return y
312

313
    def test_boxes_shape(self):
        """Run the shared box-shape checks against ``ops.roi_pool``."""
        self._helper_boxes_shape(ops.roi_pool)

316
317
318
319
    def test_jit_boxes_list(self):
        """Run the shared scripted list-of-boxes check on a wrapped ``ops.RoIPool``."""
        pooler = ops.RoIPool(output_size=[3, 3], spatial_scale=1.0)
        self._helper_jit_boxes_list(PoolWrapper(pooler))

320

321
class TestPSRoIPool(RoIOpTester):
322
323
    mps_backward_atol = 5e-2

324
325
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        """Apply a freshly built ``ops.PSRoIPool`` to ``x`` and ``rois``.

        ``spatial_scale`` is forwarded to the op, as ``make_obj`` does, rather
        than being fixed at 1. ``sampling_ratio`` and any extra keyword
        arguments are accepted but not used.
        """
        return ops.PSRoIPool((pool_h, pool_w), spatial_scale)(x, rois)
326

327
328
329
330
    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False):
        """Build an ``ops.PSRoIPool``; with ``wrap=True`` return it inside ``RoIOpTesterModuleWrapper``."""
        pooler = ops.PSRoIPool((pool_h, pool_w), spatial_scale)
        if wrap:
            return RoIOpTesterModuleWrapper(pooler)
        return pooler

331
    def get_script_fn(self, rois, pool_size):
        """Return a one-argument callable that runs scripted ``ops.ps_roi_pool``.

        The returned callable takes the input ``x`` and calls the scripted op
        with the given ``rois`` and ``pool_size``.
        """
        scripted = torch.jit.script(ops.ps_roi_pool)
        return lambda x: scripted(x, rois, pool_size)
334

335
336
337
    def expected_fn(
        self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64
    ):
        """Reference position-sensitive RoI average pooling written with plain Python loops.

        Args:
            x: input indexed as ``x[batch, channel, row, col]``; the channel
                count must be divisible by ``pool_h * pool_w``.
            rois: rows of ``(batch_idx, x1, y1, x2, y2)``.
            pool_h, pool_w: output height and width per RoI.
            spatial_scale: factor applied to the RoI coordinates before rounding.
            sampling_ratio: accepted for signature parity; not used here.
            device: device of the result; CPU when None.
            dtype: dtype of the result.

        Returns:
            Tensor of shape ``(len(rois), channels // (pool_h * pool_w), pool_h, pool_w)``
            where output bin ``(i, j)`` of output channel ``c`` averages input channel
            ``c * pool_h * pool_w + pool_w * i + j``; bins that select no element stay 0.
        """
        if device is None:
            device = torch.device("cpu")
        n_input_channels = x.size(1)
        assert n_input_channels % (pool_h * pool_w) == 0, "input channels must be divisible by ph * pw"
        n_output_channels = int(n_input_channels / (pool_h * pool_w))
        y = torch.zeros(rois.size(0), n_output_channels, pool_h, pool_w, dtype=dtype, device=device)

        def get_slice(k, block):
            # Bin k covers [floor(k * block), ceil((k + 1) * block)).
            return slice(int(np.floor(k * block)), int(np.ceil((k + 1) * block)))

        for roi_idx, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (int(round(coord.item() * spatial_scale)) for coord in roi[1:])
            roi_x = x[batch_idx, :, i_begin : i_end + 1, j_begin : j_end + 1]

            # Bin sizes come from the coordinate difference (at least 1), not
            # from the number of rows/columns sliced above.
            roi_height = max(i_end - i_begin, 1)
            roi_width = max(j_end - j_begin, 1)
            bin_h, bin_w = roi_height / float(pool_h), roi_width / float(pool_w)

            for i in range(0, pool_h):
                for j in range(0, pool_w):
                    bin_x = roi_x[:, get_slice(i, bin_h), get_slice(j, bin_w)]
                    if bin_x.numel() > 0:
                        area = bin_x.size(-2) * bin_x.size(-1)
                        for c_out in range(0, n_output_channels):
                            c_in = c_out * (pool_h * pool_w) + pool_w * i + j
                            t = torch.sum(bin_x[c_in, :, :])
                            y[roi_idx, c_out, i, j] = t / area
        return y
367

368
    def test_boxes_shape(self):
        """Run the shared box-shape checks against ``ops.ps_roi_pool``."""
        self._helper_boxes_shape(ops.ps_roi_pool)

371

372
373
def bilinear_interpolate(data, y, x, snap_border=False):
    """Bilinearly interpolate the 2-D array ``data`` at the point ``(y, x)``.

    Neighbours that fall outside ``data`` contribute nothing to the result.

    Args:
        data: 2-D array-like indexed as ``data[row, col]`` with a ``shape``.
        y: row coordinate.
        x: column coordinate.
        snap_border: when True, a coordinate in ``(-1, 0]`` is moved to 0 and
            a coordinate in ``[size - 1, size)`` is moved to ``size - 1``.

    Returns:
        The weighted sum of the in-bounds neighbours; the int 0 when none is
        in bounds.
    """
    height, width = data.shape

    if snap_border:
        if -1 < y <= 0:
            y = 0
        elif height - 1 <= y < height:
            y = height - 1

        if -1 < x <= 0:
            x = 0
        elif width - 1 <= x < width:
            x = width - 1

    y_low = int(math.floor(y))
    x_low = int(math.floor(x))
    y_high = y_low + 1
    x_high = x_low + 1

    # Weights are the fractional distances to the opposite neighbour.
    wy_h = y - y_low
    wx_h = x - x_low
    wy_l = 1 - wy_h
    wx_l = 1 - wx_h

    val = 0
    for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):
        for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):
            if 0 <= yp < height and 0 <= xp < width:
                val += wx * wy * data[yp, xp]
    return val
402
403


404
class TestRoIAlign(RoIOpTester):
405
406
    mps_backward_atol = 6e-2

AhnDW's avatar
AhnDW committed
407
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs):
        """Apply a freshly built ``ops.RoIAlign`` to ``x`` and ``rois``.

        Extra keyword arguments are accepted but not used.
        """
        return ops.RoIAlign(
            (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
        )(x, rois)
411

412
413
414
415
416
417
    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, aligned=False, wrap=False):
        """Build an ``ops.RoIAlign``; with ``wrap=True`` return it inside ``RoIOpTesterModuleWrapper``."""
        output_size = (pool_h, pool_w)
        aligner = ops.RoIAlign(
            output_size, spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
        )
        if wrap:
            return RoIOpTesterModuleWrapper(aligner)
        return aligner

418
    def get_script_fn(self, rois, pool_size):
        """Return a one-argument callable that runs scripted ``ops.roi_align``.

        The returned callable takes the input ``x`` and calls the scripted op
        with the given ``rois`` and ``pool_size``.
        """
        scripted = torch.jit.script(ops.roi_align)
        return lambda x: scripted(x, rois, pool_size)
421

422
423
424
425
426
427
428
429
430
431
432
433
    def expected_fn(
        self,
        in_data,
        rois,
        pool_h,
        pool_w,
        spatial_scale=1,
        sampling_ratio=-1,
        aligned=False,
        device=None,
        dtype=torch.float64,
    ):
        """Reference RoIAlign written with plain Python loops.

        Args:
            in_data: input indexed as ``in_data[batch, channel, row, col]``.
            rois: rows of ``(batch_idx, x1, y1, x2, y2)``.
            pool_h, pool_w: output height and width per RoI.
            spatial_scale: factor applied to the RoI coordinates.
            sampling_ratio: samples per bin along each axis when positive;
                otherwise ``ceil`` of the bin size is used.
            aligned: when True, the scaled coordinates are shifted by -0.5.
            device: device of the result; CPU when None.
            dtype: dtype of the result.

        Returns:
            Tensor of shape ``(len(rois), channels, pool_h, pool_w)`` holding the
            mean of the bilinear samples taken in each bin.
        """
        if device is None:
            device = torch.device("cpu")
        n_channels = in_data.size(1)
        out_data = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device)

        offset = 0.5 if aligned else 0.0

        for r, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (coord.item() * spatial_scale - offset for coord in roi[1:])

            roi_h = i_end - i_begin
            roi_w = j_end - j_begin
            bin_h = roi_h / pool_h
            bin_w = roi_w / pool_w
            # The sampling grid depends only on the RoI, so compute it once here.
            grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h))
            grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))

            for i in range(0, pool_h):
                start_h = i_begin + i * bin_h
                for j in range(0, pool_w):
                    start_w = j_begin + j * bin_w

                    for channel in range(0, n_channels):
                        val = 0
                        for iy in range(0, grid_h):
                            # Sample at the centre of each grid cell.
                            y = start_h + (iy + 0.5) * bin_h / grid_h
                            for ix in range(0, grid_w):
                                x = start_w + (ix + 0.5) * bin_w / grid_w
                                val += bilinear_interpolate(in_data[batch_idx, channel, :, :], y, x, snap_border=True)
                        val /= grid_h * grid_w

                        out_data[r, channel, i, j] = val
        return out_data

469
    def test_boxes_shape(self):
        """Run the shared box-shape checks against ``ops.roi_align``."""
        self._helper_boxes_shape(ops.roi_align)

472
    @pytest.mark.parametrize("aligned", (True, False))
473
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
474
    @pytest.mark.parametrize("x_dtype", (torch.float16, torch.float32, torch.float64))  # , ids=str)
475
    @pytest.mark.parametrize("contiguous", (True, False))
476
    @pytest.mark.parametrize("deterministic", (True, False))
477
    @pytest.mark.opcheck_only_one()
478
    def test_forward(self, device, contiguous, deterministic, aligned, x_dtype, rois_dtype=None):
        """Run the shared forward check with RoIAlign's ``aligned`` and ``deterministic`` options.

        The deterministic variant is skipped on CPU.
        """
        if deterministic and device == "cpu":
            pytest.skip("cpu is always deterministic, don't retest")
        super().test_forward(
            device=device,
            contiguous=contiguous,
            deterministic=deterministic,
            x_dtype=x_dtype,
            rois_dtype=rois_dtype,
            aligned=aligned,
        )
489

490
    @needs_cuda
491
    @pytest.mark.parametrize("aligned", (True, False))
492
    @pytest.mark.parametrize("deterministic", (True, False))
493
494
    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
495
    @pytest.mark.opcheck_only_one()
496
    def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
        """Run ``test_forward`` on CUDA with a non-contiguous input under CUDA autocast."""
        with torch.cuda.amp.autocast():
            self.test_forward(
                torch.device("cuda"),
                contiguous=False,
                deterministic=deterministic,
                aligned=aligned,
                x_dtype=x_dtype,
                rois_dtype=rois_dtype,
            )
506

507
    @pytest.mark.parametrize("seed", range(10))
508
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
509
510
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("deterministic", (True, False))
511
    @pytest.mark.opcheck_only_one()
512
513
514
515
516
    def test_backward(self, seed, device, contiguous, deterministic):
        """Run the shared backward check; the deterministic variant is skipped on CPU."""
        redundant_cpu_run = deterministic and device == "cpu"
        if redundant_cpu_run:
            pytest.skip("cpu is always deterministic, don't retest")
        super().test_backward(seed, device, contiguous, deterministic)

517
518
519
520
521
522
    def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
        rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype)
        rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,))  # set batch index
        rois[:, 3:] += rois[:, 1:3]  # make sure boxes aren't degenerate
        return rois

523
524
525
    @pytest.mark.parametrize("aligned", (True, False))
    @pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 10), (0.1, 50)))
    @pytest.mark.parametrize("qdtype", (torch.qint8, torch.quint8, torch.qint32))
526
    @pytest.mark.opcheck_only_one()
527
    def test_qroialign(self, aligned, scale, zero_point, qdtype):
        """Make sure quantized version of RoIAlign is close to float version"""
        pool_size = 5
        img_size = 10
        n_channels = 2
        num_imgs = 1
        dtype = torch.float

        x = torch.randint(50, 100, size=(num_imgs, n_channels, img_size, img_size)).to(dtype)
        qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=qdtype)

        rois = self._make_rois(img_size, num_imgs, dtype)
        qrois = torch.quantize_per_tensor(rois, scale=scale, zero_point=zero_point, dtype=qdtype)

        x, rois = qx.dequantize(), qrois.dequantize()  # we want to pass the same inputs

        y = ops.roi_align(
            x,
            rois,
            output_size=pool_size,
            spatial_scale=1,
            sampling_ratio=-1,
            aligned=aligned,
        )
        qy = ops.roi_align(
            qx,
            qrois,
            output_size=pool_size,
            spatial_scale=1,
            sampling_ratio=-1,
            aligned=aligned,
        )

        # The output qy is itself a quantized tensor and there might have been a loss of info when it was
        # quantized. For a fair comparison we need to quantize y as well
        quantized_float_y = torch.quantize_per_tensor(y, scale=scale, zero_point=zero_point, dtype=qdtype)

        try:
            # Ideally, we would assert this, which passes with (scale, zero) == (1, 0)
            assert (qy == quantized_float_y).all()
        except AssertionError:
            # But because the computations aren't exactly the same between the 2 RoIAlign procedures, some
            # rounding error may lead to a difference of 2 in the output.
            # For example with (scale, zero) = (2, 10), 45.00000... will be quantized to 44
            # but 45.00000001 will be rounded to 46. We make sure below that:
            # - such discrepancies between qy and quantized_float_y are very rare (less than 5%)
            # - any difference between qy and quantized_float_y is == scale
            diff_idx = torch.where(qy != quantized_float_y)
            num_diff = diff_idx[0].numel()
            assert num_diff / qy.numel() < 0.05

            abs_diff = torch.abs(qy[diff_idx].dequantize() - quantized_float_y[diff_idx].dequantize())
            t_scale = torch.full_like(abs_diff, fill_value=scale)
            torch.testing.assert_close(abs_diff, t_scale, rtol=1e-5, atol=1e-5)

    def test_qroi_align_multiple_images(self):
        """Check that quantized ``roi_align`` raises on a two-image batch."""
        dtype = torch.float
        x = torch.randint(50, 100, size=(2, 3, 10, 10)).to(dtype)
        qx = torch.quantize_per_tensor(x, scale=1, zero_point=0, dtype=torch.qint8)
        rois = self._make_rois(img_size=10, num_imgs=2, dtype=dtype, num_rois=10)
        qrois = torch.quantize_per_tensor(rois, scale=1, zero_point=0, dtype=torch.qint8)
        with pytest.raises(RuntimeError, match="Only one image per batch is allowed"):
            ops.roi_align(qx, qrois, output_size=5)
590

591
592
593
594
    def test_jit_boxes_list(self):
        """Run the shared scripted list-of-boxes check on a wrapped ``ops.RoIAlign``."""
        aligner = ops.RoIAlign(output_size=[3, 3], spatial_scale=1.0, sampling_ratio=-1)
        self._helper_jit_boxes_list(PoolWrapper(aligner))

595

596
597
598
599
600
601
602
603
604
# Ask torch's optests helper to generate opcheck tests for TestRoIAlign, using
# the OPTESTS utilities and the "torchvision" namespace. The failures-dict path
# points at optests_failures_dict.json in this file's directory.
optests.generate_opcheck_tests(
    testcase=TestRoIAlign,
    namespaces=["torchvision"],
    failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
    additional_decorators=[],
    test_utils=OPTESTS,
)


605
class TestPSRoIAlign(RoIOpTester):
606
607
    mps_backward_atol = 5e-2

608
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        """Apply a freshly built ``ops.PSRoIAlign`` to ``x`` and ``rois``.

        Extra keyword arguments are accepted but not used.
        """
        return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois)
610

611
612
613
614
    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, wrap=False):
        """Build an ``ops.PSRoIAlign``; with ``wrap=True`` return it inside ``RoIOpTesterModuleWrapper``."""
        output_size = (pool_h, pool_w)
        aligner = ops.PSRoIAlign(output_size, spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)
        if wrap:
            return RoIOpTesterModuleWrapper(aligner)
        return aligner

615
    def get_script_fn(self, rois, pool_size):
        """Return a one-argument callable that runs scripted ``ops.ps_roi_align``.

        The returned callable takes the input ``x`` and calls the scripted op
        with the given ``rois`` and ``pool_size``.
        """
        scripted = torch.jit.script(ops.ps_roi_align)
        return lambda x: scripted(x, rois, pool_size)
618

619
620
621
    def expected_fn(
        self, in_data, rois, pool_h, pool_w, device=None, spatial_scale=1, sampling_ratio=-1, dtype=torch.float64
    ):
        """Reference position-sensitive RoIAlign written with plain Python loops.

        Args:
            in_data: input indexed as ``in_data[batch, channel, row, col]``; the
                channel count must be divisible by ``pool_h * pool_w``.
            rois: rows of ``(batch_idx, x1, y1, x2, y2)``.
            pool_h, pool_w: output height and width per RoI.
            device: device of the result; CPU when None. It keeps its position
                in the signature and now defaults to None, matching the
                ``device is None`` handling below.
            spatial_scale: factor applied to the RoI coordinates.
            sampling_ratio: samples per bin along each axis when positive;
                otherwise ``ceil`` of the bin size is used.
            dtype: dtype of the result.

        Returns:
            Tensor of shape ``(len(rois), channels // (pool_h * pool_w), pool_h, pool_w)``
            where output bin ``(i, j)`` of output channel ``c`` is the mean of bilinear
            samples from input channel ``c * pool_h * pool_w + pool_w * i + j``.
        """
        if device is None:
            device = torch.device("cpu")
        n_input_channels = in_data.size(1)
        assert n_input_channels % (pool_h * pool_w) == 0, "input channels must be divisible by ph * pw"
        n_output_channels = int(n_input_channels / (pool_h * pool_w))
        out_data = torch.zeros(rois.size(0), n_output_channels, pool_h, pool_w, dtype=dtype, device=device)

        for r, roi in enumerate(rois):
            batch_idx = int(roi[0])
            # Scaled coordinates are always shifted by -0.5 here.
            j_begin, i_begin, j_end, i_end = (coord.item() * spatial_scale - 0.5 for coord in roi[1:])

            roi_h = i_end - i_begin
            roi_w = j_end - j_begin
            bin_h = roi_h / pool_h
            bin_w = roi_w / pool_w

            for i in range(0, pool_h):
                start_h = i_begin + i * bin_h
                grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h))
                for j in range(0, pool_w):
                    start_w = j_begin + j * bin_w
                    grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))
                    for c_out in range(0, n_output_channels):
                        c_in = c_out * (pool_h * pool_w) + pool_w * i + j

                        val = 0
                        for iy in range(0, grid_h):
                            # Sample at the centre of each grid cell.
                            y = start_h + (iy + 0.5) * bin_h / grid_h
                            for ix in range(0, grid_w):
                                x = start_w + (ix + 0.5) * bin_w / grid_w
                                val += bilinear_interpolate(in_data[batch_idx, c_in, :, :], y, x, snap_border=True)
                        val /= grid_h * grid_w

                        out_data[r, c_out, i, j] = val
        return out_data
657

658
    def test_boxes_shape(self):
        """Run the shared box-shape checks against ``ops.ps_roi_align``."""
        self._helper_boxes_shape(ops.ps_roi_align)

661

662
class TestMultiScaleRoIAlign:
663
664
665
666
667
668
    def make_obj(self, fmap_names=None, output_size=(7, 7), sampling_ratio=2, wrap=False):
        """Build a ``MultiScaleRoIAlign``; with ``wrap=True`` return it inside its module wrapper.

        ``fmap_names`` falls back to ``["0"]`` when None.
        """
        names = ["0"] if fmap_names is None else fmap_names
        pooler = ops.poolers.MultiScaleRoIAlign(names, output_size, sampling_ratio)
        if wrap:
            return MultiScaleRoIAlignModuleWrapper(pooler)
        return pooler

669
    def test_msroialign_repr(self):
        """Check ``repr`` of a ``MultiScaleRoIAlign`` built by ``make_obj``."""
        fmap_names = ["0"]
        output_size = (7, 7)
        sampling_ratio = 2
        # Pass mock feature map names
        t = self.make_obj(fmap_names, output_size, sampling_ratio, wrap=False)

        # Check integrity of object __repr__ attribute
        expected_string = (
            f"MultiScaleRoIAlign(featmap_names={fmap_names}, output_size={output_size}, "
            f"sampling_ratio={sampling_ratio})"
        )
        assert repr(t) == expected_string
682

683
    @pytest.mark.parametrize("device", cpu_and_cuda())
684
685
686
687
688
689
690
691
    def test_is_leaf_node(self, device):
        """Check the graph node names reported for the wrapped op.

        ``get_graph_node_names`` must return two equally long sequences, each
        holding one entry per wrapper input plus one more.
        """
        wrapped = self.make_obj(wrap=True).to(device=device)
        node_names = get_graph_node_names(wrapped)
        expected_count = 1 + wrapped.n_inputs

        assert len(node_names) == 2
        assert len(node_names[0]) == len(node_names[1])
        assert len(node_names[0]) == expected_count

692

693
694
class TestNMS:
    """Tests for ``ops.nms`` and the two ``batched_nms`` implementations."""

    def _reference_nms(self, boxes, scores, iou_threshold):
        """Straightforward greedy NMS used as the ground truth for ``ops.nms``.

        Args:
            boxes: boxes in corner-form
            scores: probabilities
            iou_threshold: intersection over union threshold
        Returns:
            picked: a 1-D tensor with the indexes of the kept boxes, in the
                order they were picked (decreasing score)
        """
        picked = []
        _, indexes = scores.sort(descending=True)
        while len(indexes) > 0:
            # The best remaining box is always kept ...
            current = indexes[0]
            picked.append(current.item())
            if len(indexes) == 1:
                break
            # ... and every remaining box overlapping it too much is dropped.
            current_box = boxes[current, :]
            indexes = indexes[1:]
            rest_boxes = boxes[indexes, :]
            iou = ops.box_iou(rest_boxes, current_box.unsqueeze(0)).squeeze(1)
            indexes = indexes[iou <= iou_threshold]

        return torch.as_tensor(picked)

    def _create_tensors_with_iou(self, N, iou_thresh):
        """Create ``N`` random boxes and scores.

        The last box is forced to have a pre-defined IoU (barely above
        ``iou_thresh``) with the first box.

        Returns:
            A ``(boxes, scores)`` tuple of shapes ``(N, 4)`` and ``(N,)``.
        """
        # force last box to have a pre-defined iou with the first box
        # let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1],
        # then, in order to satisfy ops.iou(b0, b1) == iou_thresh,
        # we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh
        boxes = torch.rand(N, 4) * 100
        boxes[:, 2:] += boxes[:, :2]
        boxes[-1, :] = boxes[0, :]
        x0, y0, x1, y1 = boxes[-1].tolist()
        # Adjust the threshold upward a bit with the intent of creating
        # at least one box that exceeds (barely) the threshold and so
        # should be suppressed.
        iou_thresh += 1e-5
        boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh
        scores = torch.rand(N)
        return boxes, scores

    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.opcheck_only_one()
    def test_nms_ref(self, iou, seed):
        """``ops.nms`` must agree with the Python reference implementation."""
        torch.random.manual_seed(seed)
        err_msg = "NMS incompatible between CPU and reference implementation for IoU={}"
        boxes, scores = self._create_tensors_with_iou(1000, iou)
        keep_ref = self._reference_nms(boxes, scores, iou)
        keep = ops.nms(boxes, scores, iou)
        torch.testing.assert_close(keep, keep_ref, msg=err_msg.format(iou))

    def test_nms_input_errors(self):
        """Mis-shaped boxes/scores must be rejected with a ``RuntimeError``."""
        with pytest.raises(RuntimeError):
            ops.nms(torch.rand(4), torch.rand(3), 0.5)
        with pytest.raises(RuntimeError):
            ops.nms(torch.rand(3, 5), torch.rand(3), 0.5)
        with pytest.raises(RuntimeError):
            ops.nms(torch.rand(3, 4), torch.rand(3, 2), 0.5)
        with pytest.raises(RuntimeError):
            ops.nms(torch.rand(3, 4), torch.rand(4), 0.5)

    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 50), (3, 10)))
    @pytest.mark.opcheck_only_one()
    def test_qnms(self, iou, scale, zero_point):
        """NMS on quantized inputs must match NMS on their dequantized values."""
        # Note: we compare qnms vs nms instead of qnms vs reference implementation.
        # This is because with the int conversion, the trick used in _create_tensors_with_iou
        # doesn't really work (in fact, nms vs reference implem will also fail with ints)
        err_msg = "NMS and QNMS give different results for IoU={}"
        boxes, scores = self._create_tensors_with_iou(1000, iou)
        scores *= 100  # otherwise most scores would be 0 or 1 after int conversion

        qboxes = torch.quantize_per_tensor(boxes, scale=scale, zero_point=zero_point, dtype=torch.quint8)
        qscores = torch.quantize_per_tensor(scores, scale=scale, zero_point=zero_point, dtype=torch.quint8)

        boxes = qboxes.dequantize()
        scores = qscores.dequantize()

        keep = ops.nms(boxes, scores, iou)
        qkeep = ops.nms(qboxes, qscores, iou)

        torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou))

    @pytest.mark.parametrize(
        "device",
        (
            pytest.param("cuda", marks=pytest.mark.needs_cuda),
            pytest.param("mps", marks=pytest.mark.needs_mps),
        ),
    )
    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.opcheck_only_one()
    def test_nms_gpu(self, iou, device, dtype=torch.float64):
        """``ops.nms`` on an accelerator must keep the same boxes as on CPU.

        ``dtype`` only selects the comparison tolerance; the boxes themselves
        come from ``_create_tensors_with_iou``.
        """
        dtype = torch.float32 if device == "mps" else dtype
        tol = 1e-3 if dtype is torch.half else 1e-5
        # Name the device actually under test (this also runs on MPS).
        err_msg = f"NMS incompatible between CPU and {device} for IoU={iou}"

        boxes, scores = self._create_tensors_with_iou(1000, iou)
        r_cpu = ops.nms(boxes, scores, iou)
        r_gpu = ops.nms(boxes.to(device), scores.to(device), iou).cpu()

        # A different number of kept boxes is a plain mismatch. Guard against it
        # explicitly: allclose would raise on the size mismatch instead of
        # letting the assertion below report the failure.
        same_shape = r_cpu.shape == r_gpu.shape
        is_eq = same_shape and torch.allclose(r_cpu, r_gpu)
        if same_shape and not is_eq:
            # if the indices are not the same, ensure that it's because the scores
            # are duplicate
            is_eq = torch.allclose(scores[r_cpu], scores[r_gpu], rtol=tol, atol=tol)
        assert is_eq, err_msg

    @needs_cuda
    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.parametrize("dtype", (torch.float, torch.half))
    @pytest.mark.opcheck_only_one()
    def test_autocast(self, iou, dtype):
        """Re-run the CPU-vs-CUDA comparison inside a CUDA autocast context."""
        with torch.cuda.amp.autocast():
            self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")

    @pytest.mark.parametrize(
        "device",
        (
            pytest.param("cuda", marks=pytest.mark.needs_cuda),
            pytest.param("mps", marks=pytest.mark.needs_mps),
        ),
    )
    @pytest.mark.opcheck_only_one()
    def test_nms_float16(self, device):
        """float16 inputs must keep the same boxes as their float32 originals."""
        boxes = torch.tensor(
            [
                [285.3538, 185.5758, 1193.5110, 851.4551],
                [285.1472, 188.7374, 1192.4984, 851.0669],
                [279.2440, 197.9812, 1189.4746, 849.2019],
            ]
        ).to(device)
        scores = torch.tensor([0.6370, 0.7569, 0.3966]).to(device)

        iou_thres = 0.2
        keep32 = ops.nms(boxes, scores, iou_thres)
        keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
        assert_equal(keep32, keep16)

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.opcheck_only_one()
    def test_batched_nms_implementations(self, seed):
        """Make sure that both implementations of batched_nms yield identical results"""
        torch.random.manual_seed(seed)

        num_boxes = 1000
        iou_threshold = 0.9

        # Top-left corners lie in [0, 1) and bottom-right corners in [10, 11),
        # so every box is valid; the asserts below double-check that.
        boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1)
        assert max(boxes[:, 0]) < min(boxes[:, 2])  # x1 < x2
        assert max(boxes[:, 1]) < min(boxes[:, 3])  # y1 < y2

        scores = torch.rand(num_boxes)
        idxs = torch.randint(0, 4, size=(num_boxes,))
        keep_vanilla = ops.boxes._batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
        keep_trick = ops.boxes._batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)

        torch.testing.assert_close(
            keep_vanilla, keep_trick, msg="The vanilla and the trick implementation yield different nms outputs."
        )

        # Also make sure an empty tensor is returned if boxes is empty
        empty = torch.empty((0,), dtype=torch.int64)
        torch.testing.assert_close(empty, ops.batched_nms(empty, None, None, None))
859

860

861
862
863
864
865
866
867
868
869
# Ask torch's optests helper to generate opcheck tests from TestNMS for the
# "torchvision" namespace, using the test utilities listed in OPTESTS. The
# failures-dict path points at optests_failures_dict.json next to this file.
optests.generate_opcheck_tests(
    testcase=TestNMS,
    namespaces=["torchvision"],
    failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
    additional_decorators=[],
    test_utils=OPTESTS,
)


870
871
872
class TestDeformConv:
    """Tests for ``ops.deform_conv2d`` and the ``ops.DeformConv2d`` module."""

    # Default dtype for the test inputs.
    dtype = torch.float64

    def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
        """Reference deformable convolution written as explicit Python loops.

        Args:
            x: input of shape ``(n_batches, n_in_channels, in_h, in_w)``.
            weight: kernel of shape ``(n_out_channels, in_c_per_weight_grp, weight_h, weight_w)``.
            offset: sampling offsets; channel ``2 * k`` holds the vertical and
                ``2 * k + 1`` the horizontal offset of kernel tap ``k``.
            mask: optional modulation mask (one channel per kernel tap and
                offset group); ``None`` means every tap is weighted by 1.
            bias: per-output-channel bias, added once at the end.
            stride, padding, dilation: int or pair, expanded with ``_pair``.

        Returns:
            Tensor of shape ``(n_batches, n_out_channels, out_h, out_w)`` on the
            device and with the dtype of ``x``.
        """
        stride_h, stride_w = _pair(stride)
        pad_h, pad_w = _pair(padding)
        dil_h, dil_w = _pair(dilation)
        weight_h, weight_w = weight.shape[-2:]

        n_batches, n_in_channels, in_h, in_w = x.shape
        n_out_channels = weight.shape[0]

        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1

        n_offset_grps = offset.shape[1] // (2 * weight_h * weight_w)
        in_c_per_offset_grp = n_in_channels // n_offset_grps

        n_weight_grps = n_in_channels // weight.shape[1]
        in_c_per_weight_grp = weight.shape[1]
        out_c_per_weight_grp = n_out_channels // n_weight_grps

        out = torch.zeros(n_batches, n_out_channels, out_h, out_w, device=x.device, dtype=x.dtype)
        for b in range(n_batches):
            for c_out in range(n_out_channels):
                for i in range(out_h):
                    for j in range(out_w):
                        for di in range(weight_h):
                            for dj in range(weight_w):
                                for c in range(in_c_per_weight_grp):
                                    # Map the group-local channel to the absolute input channel.
                                    weight_grp = c_out // out_c_per_weight_grp
                                    c_in = weight_grp * in_c_per_weight_grp + c

                                    offset_grp = c_in // in_c_per_offset_grp
                                    mask_idx = offset_grp * (weight_h * weight_w) + di * weight_w + dj
                                    offset_idx = 2 * mask_idx

                                    # Regular convolution sampling position shifted by the learned offset.
                                    pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j]
                                    pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j]

                                    mask_value = 1.0
                                    if mask is not None:
                                        mask_value = mask[b, mask_idx, i, j]

                                    out[b, c_out, i, j] += (
                                        mask_value
                                        * weight[c_out, c, di, dj]
                                        * bilinear_interpolate(x[b, c_in, :, :], pi, pj)
                                    )
        out += bias.view(1, n_out_channels, 1, 1)
        return out

    # NOTE(review): lru_cache on a method keys on ``self`` as well and keeps
    # every instance (and its tensors) alive for the cache's lifetime.
    @lru_cache(maxsize=None)
    def get_fn_args(self, device, contiguous, batch_sz, dtype):
        """Create random inputs for a fixed deformable-convolution configuration.

        Args:
            device: device on which the tensors are created.
            contiguous: when false, ``x``, ``offset``, ``mask`` and ``weight``
                are returned as non-contiguous views with the same shapes.
            batch_sz: batch size of ``x``, ``offset`` and ``mask``.
            dtype: dtype of every tensor.

        Returns:
            ``(x, weight, offset, mask, bias, stride, pad, dilation)``.
        """
        n_in_channels = 6
        n_out_channels = 2
        n_weight_grps = 2
        n_offset_grps = 3

        stride = (2, 1)
        pad = (1, 0)
        dilation = (2, 1)

        stride_h, stride_w = stride
        pad_h, pad_w = pad
        dil_h, dil_w = dilation
        weight_h, weight_w = (3, 2)
        in_h, in_w = (5, 4)

        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1

        x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=dtype, requires_grad=True)

        offset = torch.randn(
            batch_sz,
            n_offset_grps * 2 * weight_h * weight_w,
            out_h,
            out_w,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )

        mask = torch.randn(
            batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w, device=device, dtype=dtype, requires_grad=True
        )

        weight = torch.randn(
            n_out_channels,
            n_in_channels // n_weight_grps,
            weight_h,
            weight_w,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )

        bias = torch.randn(n_out_channels, device=device, dtype=dtype, requires_grad=True)

        if not contiguous:
            # Permute, materialize, permute back: same shape, different strides.
            x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
            offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)

        return x, weight, offset, mask, bias, stride, pad, dilation

    def make_obj(self, in_channels=6, out_channels=2, kernel_size=(3, 2), groups=2, wrap=False):
        """Build a ``DeformConv2d`` layer matching the ``get_fn_args`` configuration.

        When ``wrap`` is true the layer is returned inside a ``DeformConvModuleWrapper``.
        """
        obj = ops.DeformConv2d(
            in_channels, out_channels, kernel_size, stride=(2, 1), padding=(1, 0), dilation=(2, 1), groups=groups
        )
        return DeformConvModuleWrapper(obj) if wrap else obj

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_is_leaf_node(self, device):
        """The wrapped layer must show up as a single node next to its inputs."""
        op_obj = self.make_obj(wrap=True).to(device=device)
        graph_node_names = get_graph_node_names(op_obj)

        assert len(graph_node_names) == 2
        assert len(graph_node_names[0]) == len(graph_node_names[1])
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("batch_sz", (0, 33))
    def test_forward(self, device, contiguous, batch_sz, dtype=None):
        """The layer output must match ``expected_fn``, with and without a mask."""
        dtype = dtype or self.dtype
        x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
        in_channels = 6
        out_channels = 2
        kernel_size = (3, 2)
        groups = 2
        tol = 2e-3 if dtype is torch.half else 1e-5

        layer = self.make_obj(in_channels, out_channels, kernel_size, groups, wrap=False).to(
            device=x.device, dtype=dtype
        )
        res = layer(x, offset, mask)

        # Feed the layer's own parameters to the reference implementation.
        weight = layer.weight.data
        bias = layer.bias.data
        expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation)

        torch.testing.assert_close(
            res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}"
        )

        # no modulation test
        res = layer(x, offset)
        expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation)

        torch.testing.assert_close(
            res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}"
        )

    def test_wrong_sizes(self):
        """Offsets or masks with the wrong channel count must raise ``RuntimeError``."""
        in_channels = 6
        out_channels = 2
        kernel_size = (3, 2)
        groups = 2
        x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(
            "cpu", contiguous=True, batch_sz=10, dtype=self.dtype
        )
        layer = ops.DeformConv2d(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups
        )
        with pytest.raises(RuntimeError, match="the shape of the offset"):
            wrong_offset = torch.rand_like(offset[:, :2])
            layer(x, wrong_offset)

        with pytest.raises(RuntimeError, match=r"mask.shape\[1\] is not valid"):
            wrong_mask = torch.rand_like(mask[:, :2])
            layer(x, offset, wrong_mask)

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("batch_sz", (0, 33))
    def test_backward(self, device, contiguous, batch_sz):
        """Gradient-check ``ops.deform_conv2d`` in eager and scripted form, with and without a mask."""
        x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(
            device, contiguous, batch_sz, self.dtype
        )

        def func(x_, offset_, mask_, weight_, bias_):
            return ops.deform_conv2d(
                x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=mask_
            )

        gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5, fast_mode=True)

        def func_no_mask(x_, offset_, weight_, bias_):
            return ops.deform_conv2d(
                x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=None
            )

        gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5, fast_mode=True)

        @torch.jit.script
        def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_):
            # type:(Tensor, Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
            return ops.deform_conv2d(
                x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_, mask=mask_
            )

        gradcheck(
            lambda z, off, msk, wei, bi: script_func(z, off, msk, wei, bi, stride, padding, dilation),
            (x, offset, mask, weight, bias),
            nondet_tol=1e-5,
            fast_mode=True,
        )

        @torch.jit.script
        def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
            # type:(Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
            return ops.deform_conv2d(
                x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_, mask=None
            )

        gradcheck(
            lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation),
            (x, offset, weight, bias),
            nondet_tol=1e-5,
            fast_mode=True,
        )

    @needs_cuda
    @pytest.mark.parametrize("contiguous", (True, False))
    def test_compare_cpu_cuda_grads(self, contiguous):
        """Weight gradients computed on CUDA must match those computed on CPU."""
        # Test from https://github.com/pytorch/vision/issues/2598
        # Run on CUDA only

        # compare grads computed on CUDA with grads computed on CPU
        true_cpu_grads = None

        init_weight = torch.randn(9, 9, 3, 3, requires_grad=True)
        img = torch.randn(8, 9, 1000, 110)
        offset = torch.rand(8, 2 * 3 * 3, 1000, 110)
        mask = torch.rand(8, 3 * 3, 1000, 110)

        if not contiguous:
            img = img.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
            offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            weight = init_weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
        else:
            weight = init_weight

        for d in ["cpu", "cuda"]:
            # Both passes backpropagate into the same CPU leaf. Start each pass
            # from a clean slate so gradients do not accumulate across devices.
            init_weight.grad = None
            out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d))
            out.mean().backward()
            assert init_weight.grad is not None
            # Snapshot the gradient: keeping a reference to ``.grad`` itself
            # would alias the tensor the next pass writes into, turning the
            # comparison below into a tensor compared with itself.
            grads = init_weight.grad.detach().clone().to("cpu")
            if true_cpu_grads is None:
                true_cpu_grads = grads
            else:
                torch.testing.assert_close(true_cpu_grads, grads)

    @needs_cuda
    @pytest.mark.parametrize("batch_sz", (0, 33))
    @pytest.mark.parametrize("dtype", (torch.float, torch.half))
    def test_autocast(self, batch_sz, dtype):
        """Re-run the forward comparison on CUDA inside an autocast context."""
        with torch.cuda.amp.autocast():
            self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype)

    def test_forward_scriptability(self):
        """``DeformConv2d`` must be scriptable with ``torch.jit.script``."""
        # Non-regression test for https://github.com/pytorch/vision/issues/4078
        torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3))

1139
1140

class TestFrozenBNT:
    """Tests for ``ops.misc.FrozenBatchNorm2d``."""

    def test_frozenbatchnorm2d_repr(self):
        """``repr`` must show the number of features and ``eps``."""
        num_features = 32
        eps = 1e-5
        layer = ops.misc.FrozenBatchNorm2d(num_features, eps=eps)

        # Check integrity of object __repr__ attribute
        expected_string = f"FrozenBatchNorm2d({num_features}, eps={eps})"
        actual_string = repr(layer)
        assert actual_string == expected_string

    @pytest.mark.parametrize("seed", range(10))
    def test_frozenbatchnorm2d_eps(self, seed):
        """The frozen layer must agree with an eval-mode ``BatchNorm2d`` for default and explicit eps."""
        torch.random.manual_seed(seed)
        sample_size = (4, 32, 28, 28)
        num_channels = sample_size[1]
        x = torch.rand(sample_size)
        state_dict = {
            "weight": torch.rand(num_channels),
            "bias": torch.rand(num_channels),
            "running_mean": torch.rand(num_channels),
            "running_var": torch.rand(num_channels),
            "num_batches_tracked": torch.tensor(100),
        }

        # First pass: no eps given, which checks that the default eps is equal
        # to the one of BN. Second pass: check computation for eps > 0.
        for eps_kwargs in ({}, {"eps": 1e-5}):
            frozen = ops.misc.FrozenBatchNorm2d(num_channels, **eps_kwargs)
            frozen.load_state_dict(state_dict, strict=False)
            reference = torch.nn.BatchNorm2d(num_channels, **eps_kwargs).eval()
            reference.load_state_dict(state_dict)
            # Difference is expected to fall in an acceptable range
            torch.testing.assert_close(frozen(x), reference(x), rtol=1e-5, atol=1e-6)
1177

1178

Aditya Oke's avatar
Aditya Oke committed
1179
class TestBoxConversionToRoi:
    """Tests for the box-argument helpers in ``ops._utils``."""

    def _get_box_sequences():
        """Return the same two boxes as a roi tensor, a list and a tuple.

        Deliberately takes no ``self``: it is called in the class body to feed
        ``pytest.mark.parametrize`` and through the class inside the tests.
        """
        # Define here the argument type of `boxes` supported by region pooling operations
        # Tensor form: first column is the batch index, then the box coordinates.
        box_tensor = torch.tensor([[0, 0, 0, 100, 100], [1, 0, 0, 100, 100]], dtype=torch.float)
        # Sequence form: one tensor of boxes per batch element.
        box_list = [
            torch.tensor([[0, 0, 100, 100]], dtype=torch.float),
            torch.tensor([[0, 0, 100, 100]], dtype=torch.float),
        ]
        box_tuple = tuple(box_list)
        return box_tensor, box_list, box_tuple

    @pytest.mark.parametrize("box_sequence", _get_box_sequences())
    def test_check_roi_boxes_shape(self, box_sequence):
        # Ensure common sequences of tensors are supported
        ops._utils.check_roi_boxes_shape(box_sequence)

    @pytest.mark.parametrize("box_sequence", _get_box_sequences())
    def test_convert_boxes_to_roi_format(self, box_sequence):
        """Every supported way of passing boxes must yield the same roi tensor."""
        # Ensure common sequences of tensors yield the same result.
        # The tensor form is the reference every other form must convert to.
        ref_tensor = type(self)._get_box_sequences()[0]
        if isinstance(box_sequence, torch.Tensor):
            # NOTE(review): a Tensor is presumably already in roi format and is
            # not passed through convert_boxes_to_roi_format by the pooling ops
            # -- confirm against ops._utils.
            rois = box_sequence
        else:
            rois = ops._utils.convert_boxes_to_roi_format(box_sequence)
        assert_equal(ref_tensor, rois)
1203
1204


Aditya Oke's avatar
Aditya Oke committed
1205
class TestBoxConvert:
    """Tests for ``ops.box_convert`` between the xyxy, xywh and cxcywh formats."""

    def test_bbox_same(self):
        """Converting to the very same format must leave the boxes untouched."""
        boxes = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
        )
        unchanged = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
        )

        assert unchanged.size() == torch.Size([4, 4])
        for fmt in ("xyxy", "xywh", "cxcywh"):
            assert_equal(ops.box_convert(boxes, in_fmt=fmt, out_fmt=fmt), unchanged)

    def test_bbox_xyxy_xywh(self):
        """Round trip xyxy -> xywh -> xyxy."""
        # Simple test convert boxes to xywh and back. Make sure they are same.
        # The input boxes are in x1 y1 x2 y2 format.
        boxes_xyxy = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
        )
        expected_xywh = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float
        )

        assert expected_xywh.size() == torch.Size([4, 4])
        converted = ops.box_convert(boxes_xyxy, in_fmt="xyxy", out_fmt="xywh")
        assert_equal(converted, expected_xywh)

        # Reverse conversion
        restored = ops.box_convert(converted, in_fmt="xywh", out_fmt="xyxy")
        assert_equal(restored, boxes_xyxy)

    def test_bbox_xyxy_cxcywh(self):
        """Round trip xyxy -> cxcywh -> xyxy."""
        # Simple test convert boxes to cxcywh and back. Make sure they are same.
        # The input boxes are in x1 y1 x2 y2 format.
        boxes_xyxy = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
        )
        expected_cxcywh = torch.tensor(
            [[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float
        )

        assert expected_cxcywh.size() == torch.Size([4, 4])
        converted = ops.box_convert(boxes_xyxy, in_fmt="xyxy", out_fmt="cxcywh")
        assert_equal(converted, expected_cxcywh)

        # Reverse conversion
        restored = ops.box_convert(converted, in_fmt="cxcywh", out_fmt="xyxy")
        assert_equal(restored, boxes_xyxy)

    def test_bbox_xywh_cxcywh(self):
        """Round trip xywh -> cxcywh -> xywh."""
        boxes_xywh = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float
        )
        expected_cxcywh = torch.tensor(
            [[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float
        )

        assert expected_cxcywh.size() == torch.Size([4, 4])
        converted = ops.box_convert(boxes_xywh, in_fmt="xywh", out_fmt="cxcywh")
        assert_equal(converted, expected_cxcywh)

        # Reverse conversion
        restored = ops.box_convert(converted, in_fmt="cxcywh", out_fmt="xywh")
        assert_equal(restored, boxes_xywh)

    @pytest.mark.parametrize("inv_infmt", ["xwyh", "cxwyh"])
    @pytest.mark.parametrize("inv_outfmt", ["xwcx", "xhwcy"])
    def test_bbox_invalid(self, inv_infmt, inv_outfmt):
        """Unknown format names must be rejected with a ``ValueError``."""
        boxes = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float
        )

        with pytest.raises(ValueError):
            ops.box_convert(boxes, inv_infmt, inv_outfmt)

    def test_bbox_convert_jit(self):
        """The scripted ``box_convert`` must match the eager one."""
        boxes = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
        )

        scripted_fn = torch.jit.script(ops.box_convert)

        for target_fmt in ("xywh", "cxcywh"):
            eager_out = ops.box_convert(boxes, in_fmt="xyxy", out_fmt=target_fmt)
            scripted_out = scripted_fn(boxes, "xyxy", target_fmt)
            torch.testing.assert_close(scripted_out, eager_out)
1293
1294


Aditya Oke's avatar
Aditya Oke committed
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
class TestBoxArea:
    """Tests for ``ops.box_area``."""

    def area_check(self, box, expected, atol=1e-4):
        """Assert that ``ops.box_area(box)`` is within ``atol`` of ``expected``, ignoring dtype."""
        torch.testing.assert_close(ops.box_area(box), expected, rtol=0.0, check_dtype=False, atol=atol)

    @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64])
    def test_int_boxes(self, dtype):
        """Integer boxes of every width must produce the same areas."""
        int_boxes = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype)
        expected_area = torch.tensor([10000, 0], dtype=torch.int32)
        self.area_check(int_boxes, expected_area)

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
    def test_float_boxes(self, dtype):
        """Areas of the shared ``FLOAT_BOXES`` fixture in single and double precision."""
        float_boxes = torch.tensor(FLOAT_BOXES, dtype=dtype)
        expected_area = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=dtype)
        self.area_check(float_boxes, expected_area)

    def test_float16_box(self):
        """Half-precision boxes are checked with a looser tolerance."""
        half_boxes = torch.tensor(
            [[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]], dtype=torch.float16
        )
        expected_area = torch.tensor([3.2170, 3.7108, 18.5071], dtype=torch.float16)
        self.area_check(half_boxes, expected_area, atol=0.01)

    def test_box_area_jit(self):
        """The scripted ``box_area`` must match the eager one."""
        boxes = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)
        eager_area = ops.box_area(boxes)
        scripted_area = torch.jit.script(ops.box_area)(boxes)
        torch.testing.assert_close(scripted_area, eager_area)
1326

Aditya Oke's avatar
Aditya Oke committed
1327

1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
# Shared box fixtures used by the box-area and IoU tests in this file.
# NOTE(review): presumably in (x1, y1, x2, y2) corner format -- confirm against the ops under test.
INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300], [0, 0, 25, 25]]
# Same as INT_BOXES without its last box.
INT_BOXES2 = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
# Three large float boxes with nearly identical coordinates.
FLOAT_BOXES = [
    [285.3538, 185.5758, 1193.5110, 851.4551],
    [285.1472, 188.7374, 1192.4984, 851.0669],
    [279.2440, 197.9812, 1189.4746, 849.2019],
]


def gen_box(size, dtype=torch.float):
    """Generate ``size`` random boxes as a ``(size, 4)`` tensor of ``dtype``.

    The first two columns are one ``torch.rand`` draw; the last two are those
    values plus a second ``torch.rand`` draw, so they are never smaller than
    the first two.
    """
    top_left = torch.rand((size, 2), dtype=dtype)
    extent = torch.rand((size, 2), dtype=dtype)
    bottom_right = top_left + extent
    return torch.cat((top_left, bottom_right), dim=-1)


Aditya Oke's avatar
Aditya Oke committed
1343
1344
class TestIouBase:
    """Shared helpers for the pairwise box-IoU test classes."""

    @staticmethod
    def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected):
        """Check ``target_fn(box1, box2)`` against ``expected`` for every dtype.

        Args:
            target_fn: function under test, called with two box tensors.
            actual_box1, actual_box2: box data, converted to each dtype in turn.
            dtypes: dtypes to run the check with.
            atol: absolute tolerance (the relative tolerance is 0).
            expected: expected result, compared without checking the dtype.
        """
        expected_box = torch.tensor(expected)
        for dtype in dtypes:
            # Always build the tensors from the original data. Rebinding the
            # arguments would convert from the previous iteration's tensor,
            # e.g. float64 would only ever see float32-rounded values.
            box1 = torch.tensor(actual_box1, dtype=dtype)
            box2 = torch.tensor(actual_box2, dtype=dtype)
            out = target_fn(box1, box2)
            torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)

    @staticmethod
    def _run_jit_test(target_fn: Callable, actual_box: List):
        """Check that the scripted ``target_fn`` matches the eager one on ``actual_box``."""
        box_tensor = torch.tensor(actual_box, dtype=torch.float)
        expected = target_fn(box_tensor, box_tensor)
        scripted_fn = torch.jit.script(target_fn)
        scripted_out = scripted_fn(box_tensor, box_tensor)
        torch.testing.assert_close(scripted_out, expected)

    @staticmethod
    def _cartesian_product(boxes1, boxes2, target_fn: Callable):
        """Build the ``(N, M)`` result by calling ``target_fn`` on one pair of boxes at a time."""
        N = boxes1.size(0)
        M = boxes2.size(0)
        result = torch.zeros((N, M))
        for i in range(N):
            for j in range(M):
                result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0))
        return result

    @staticmethod
    def _run_cartesian_test(target_fn: Callable):
        """Check that the batched call agrees with the pair-by-pair computation."""
        boxes1 = gen_box(5)
        boxes2 = gen_box(7)
        a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn)
        b = target_fn(boxes1, boxes2)
        torch.testing.assert_close(a, b)
1378

1379

Aditya Oke's avatar
Aditya Oke committed
1380
class TestBoxIou(TestIouBase):
    """Checks for ``ops.box_iou`` built on the shared ``TestIouBase`` helpers."""

    # Reference output for the (INT_BOXES, INT_BOXES2) pairing.
    int_expected = [
        [1.0, 0.25, 0.0],
        [0.25, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.0625, 0.25, 0.0],
    ]
    # Reference output for FLOAT_BOXES compared against itself.
    float_expected = [
        [1.0, 0.9933, 0.9673],
        [0.9933, 1.0, 0.9737],
        [0.9673, 0.9737, 1.0],
    ]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            (INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        """The operator matches the reference values for each dtype group."""
        op = ops.box_iou
        self._run_test(op, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        """The scripted operator agrees with eager mode."""
        op = ops.box_iou
        self._run_jit_test(op, INT_BOXES)

    def test_iou_cartesian(self):
        """The batched result agrees with the pair-by-pair evaluation."""
        op = ops.box_iou
        self._run_cartesian_test(op)

1401

Aditya Oke's avatar
Aditya Oke committed
1402
class TestGeneralizedBoxIou(TestIouBase):
    """Checks for ``ops.generalized_box_iou`` built on the shared ``TestIouBase`` helpers."""

    # Reference output for the (INT_BOXES, INT_BOXES2) pairing.
    int_expected = [
        [1.0, 0.25, -0.7778],
        [0.25, 1.0, -0.8611],
        [-0.7778, -0.8611, 1.0],
        [0.0625, 0.25, -0.8819],
    ]
    # Reference output for FLOAT_BOXES compared against itself.
    float_expected = [
        [1.0, 0.9933, 0.9673],
        [0.9933, 1.0, 0.9737],
        [0.9673, 0.9737, 1.0],
    ]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            (INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        """The operator matches the reference values for each dtype group."""
        op = ops.generalized_box_iou
        self._run_test(op, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        """The scripted operator agrees with eager mode."""
        op = ops.generalized_box_iou
        self._run_jit_test(op, INT_BOXES)

    def test_iou_cartesian(self):
        """The batched result agrees with the pair-by-pair evaluation."""
        op = ops.generalized_box_iou
        self._run_cartesian_test(op)

1423

Aditya Oke's avatar
Aditya Oke committed
1424
class TestDistanceBoxIoU(TestIouBase):
    """Checks for ``ops.distance_box_iou`` built on the shared ``TestIouBase`` helpers."""

    # Reference output for the (INT_BOXES, INT_BOXES2) pairing.
    int_expected = [
        [1.0000, 0.1875, -0.4444],
        [0.1875, 1.0000, -0.5625],
        [-0.4444, -0.5625, 1.0000],
        [-0.0781, 0.1875, -0.6267],
    ]
    # Reference output for FLOAT_BOXES compared against itself.
    float_expected = [
        [1.0, 0.9933, 0.9673],
        [0.9933, 1.0, 0.9737],
        [0.9673, 0.9737, 1.0],
    ]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            (INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        """The operator matches the reference values for each dtype group."""
        op = ops.distance_box_iou
        self._run_test(op, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        """The scripted operator agrees with eager mode."""
        op = ops.distance_box_iou
        self._run_jit_test(op, INT_BOXES)

    def test_iou_cartesian(self):
        """The batched result agrees with the pair-by-pair evaluation."""
        op = ops.distance_box_iou
        self._run_cartesian_test(op)

1450

Aditya Oke's avatar
Aditya Oke committed
1451
class TestCompleteBoxIou(TestIouBase):
    """Checks for ``ops.complete_box_iou`` built on the shared ``TestIouBase`` helpers."""

    # Reference output for the (INT_BOXES, INT_BOXES2) pairing.
    int_expected = [
        [1.0000, 0.1875, -0.4444],
        [0.1875, 1.0000, -0.5625],
        [-0.4444, -0.5625, 1.0000],
        [-0.0781, 0.1875, -0.6267],
    ]
    # Reference output for FLOAT_BOXES compared against itself.
    float_expected = [
        [1.0, 0.9933, 0.9673],
        [0.9933, 1.0, 0.9737],
        [0.9673, 0.9737, 1.0],
    ]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            (INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        """The operator matches the reference values for each dtype group."""
        op = ops.complete_box_iou
        self._run_test(op, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        """The scripted operator agrees with eager mode."""
        op = ops.complete_box_iou
        self._run_jit_test(op, INT_BOXES)

    def test_iou_cartesian(self):
        """The batched result agrees with the pair-by-pair evaluation."""
        op = ops.complete_box_iou
        self._run_cartesian_test(op)

1477

Aditya Oke's avatar
Aditya Oke committed
1478
1479
1480
1481
1482
def get_boxes(dtype, device):
    """Create the fixed box fixtures shared by the IoU-loss tests.

    Returns:
        ``(box1, box2, box3, box4, box1s, box2s)``: four single boxes of four
        values each, then two stacked pairs, ``box1s = [box2, box2]`` and
        ``box2s = [box3, box4]``. All tensors use ``dtype`` and ``device``.
    """
    coordinates = (
        [-1, -1, 1, 1],
        [0, 0, 1, 1],
        [0, 1, 1, 2],
        [1, 1, 2, 2],
    )
    box1, box2, box3, box4 = (torch.tensor(coords, dtype=dtype, device=device) for coords in coordinates)

    # Batched operands: the same first box paired with two different partners.
    box1s = torch.stack((box2, box2), dim=0)
    box2s = torch.stack((box3, box4), dim=0)

    return box1, box2, box3, box4, box1s, box2s
1488

Aditya Oke's avatar
Aditya Oke committed
1489

Aditya Oke's avatar
Aditya Oke committed
1490
1491
1492
1493
def assert_iou_loss(iou_fn, box1, box2, expected_loss, device, reduction="none"):
    """Assert that ``iou_fn(box1, box2, reduction=reduction)`` is close to ``expected_loss``.

    ``expected_loss`` is turned into a tensor on ``device`` before the comparison,
    which uses the default tolerances of ``torch.testing.assert_close``.
    """
    actual = iou_fn(box1, box2, reduction=reduction)
    reference = torch.tensor(expected_loss, device=device)
    torch.testing.assert_close(actual, reference)
1494
1495


Aditya Oke's avatar
Aditya Oke committed
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
def assert_empty_loss(iou_fn, dtype, device):
    """Check how ``iou_fn`` behaves on two empty ``(0, 4)`` box tensors.

    With ``reduction="mean"`` the loss must be close to 0 and backward must
    populate ``.grad`` on both inputs; with ``reduction="none"`` the result must
    contain no elements.
    """
    box1, box2 = (torch.randn([0, 4], dtype=dtype, device=device).requires_grad_() for _ in range(2))

    mean_loss = iou_fn(box1, box2, reduction="mean")
    mean_loss.backward()
    zero = torch.tensor(0.0, device=device)
    torch.testing.assert_close(mean_loss, zero)
    assert box1.grad is not None, "box1.grad should not be None after backward is called"
    assert box2.grad is not None, "box2.grad should not be None after backward is called"

    unreduced = iou_fn(box1, box2, reduction="none")
    assert unreduced.numel() == 0, f"{str(iou_fn)} for two empty box should be empty"
Aditya Oke's avatar
Aditya Oke committed
1506

Aditya Oke's avatar
Aditya Oke committed
1507

Aditya Oke's avatar
Aditya Oke committed
1508
1509
class TestGeneralizedBoxIouLoss:
    """Tests for ``ops.generalized_box_iou_loss``."""

    # We refer to original test: https://github.com/facebookresearch/fvcore/blob/main/tests/test_giou_loss.py
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_giou_loss(self, dtype, device):
        """Check the unreduced, summed and averaged loss on the fixed box fixtures."""
        loss_fn = ops.generalized_box_iou_loss
        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

        unreduced_cases = (
            # Identical boxes should have loss of 0
            (box1, box1, 0.0),
            # quarter size box inside other box = IoU of 0.25
            (box1, box2, 0.75),
            # Two side by side boxes, area=union
            # IoU=0 and GIoU=0 (loss 1.0)
            (box2, box3, 1.0),
            # Two diagonally adjacent boxes, area=2*union
            # IoU=0 and GIoU=-0.5 (loss 1.5)
            (box2, box4, 1.5),
        )
        for first, second, expected in unreduced_cases:
            assert_iou_loss(loss_fn, first, second, expected, device=device)

        # Test batched loss and reductions
        for reduction, expected in (("sum", 2.5), ("mean", 1.25)):
            assert_iou_loss(loss_fn, box1s, box2s, expected, device=device, reduction=reduction)

        # Test reduction value
        # reduction value other than ["none", "mean", "sum"] should raise a ValueError
        with pytest.raises(ValueError, match="Invalid"):
            loss_fn(box1s, box2s, reduction="xyz")

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_empty_inputs(self, dtype, device):
        """Empty box tensors must be handled as checked by ``assert_empty_loss``."""
        assert_empty_loss(ops.generalized_box_iou_loss, dtype, device)
Yassine Alouini's avatar
Yassine Alouini committed
1542
1543


Aditya Oke's avatar
Aditya Oke committed
1544
1545
class TestCompleteBoxIouLoss:
    """Tests for ``ops.complete_box_iou_loss``."""

    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_ciou_loss(self, dtype, device):
        """Check the unreduced, averaged and summed loss on the fixed box fixtures."""
        loss_fn = ops.complete_box_iou_loss
        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

        # box1 against each single box, no reduction.
        for other, expected in ((box1, 0.0), (box2, 0.8125), (box3, 1.1923), (box4, 1.2500)):
            assert_iou_loss(loss_fn, box1, other, expected, device=device)

        # Batched operands with each supported reduction.
        for reduction, expected in (("mean", 1.2250), ("sum", 2.4500)):
            assert_iou_loss(loss_fn, box1s, box2s, expected, device=device, reduction=reduction)

        # An unrecognised reduction name must be rejected.
        with pytest.raises(ValueError, match="Invalid"):
            loss_fn(box1s, box2s, reduction="xyz")

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_empty_inputs(self, dtype, device):
        """Empty box tensors must be handled as checked by ``assert_empty_loss``."""
        assert_empty_loss(ops.complete_box_iou_loss, dtype, device)
Yassine Alouini's avatar
Yassine Alouini committed
1564
1565


Aditya Oke's avatar
Aditya Oke committed
1566
class TestDistanceBoxIouLoss:
    """Tests for ``ops.distance_box_iou_loss``."""

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_distance_iou_loss(self, dtype, device):
        """Check the unreduced, averaged and summed loss on the fixed box fixtures."""
        loss_fn = ops.distance_box_iou_loss
        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

        # box1 against each single box, no reduction.
        for other, expected in ((box1, 0.0), (box2, 0.8125), (box3, 1.1923), (box4, 1.2500)):
            assert_iou_loss(loss_fn, box1, other, expected, device=device)

        # Batched operands with each supported reduction.
        for reduction, expected in (("mean", 1.2250), ("sum", 2.4500)):
            assert_iou_loss(loss_fn, box1s, box2s, expected, device=device, reduction=reduction)

        # An unrecognised reduction name must be rejected.
        with pytest.raises(ValueError, match="Invalid"):
            loss_fn(box1s, box2s, reduction="xyz")

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_empty_distance_iou_inputs(self, dtype, device):
        """Empty box tensors must be handled as checked by ``assert_empty_loss``."""
        assert_empty_loss(ops.distance_box_iou_loss, dtype, device)
Yassine Alouini's avatar
Yassine Alouini committed
1586
1587


Aditya Oke's avatar
Aditya Oke committed
1588
1589
1590
1591
class TestFocalLoss:
    """Tests for ``ops.sigmoid_focal_loss``."""

    def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs):
        """Build logits and targets covering nine input/target range combinations.

        Inputs are drawn from the "small", "big" and "random" probability ranges
        and converted to logits; each is paired with all-zero, all-one and random
        binary targets. ``kwargs`` is forwarded to the tensor factories.

        Returns:
            ``(inputs, targets)``, each of shape ``(shape[0] * 9, shape[1])``.
        """
        value_ranges = {
            "small": (0.0, 0.2),
            "big": (0.8, 1.0),
            "zeros": (0.0, 0.0),
            "ones": (1.0, 1.0),
            "random": (0.0, 1.0),
        }

        def to_logit(prob):
            return torch.log(prob / (1 - prob))

        def sample(kind):
            # "random_binary" is the only kind not described by a (low, high) range.
            if kind == "random_binary":
                return torch.randint(0, 2, shape, **kwargs)
            lower, upper = value_ranges[kind]
            return torch.testing.make_tensor(shape, low=lower, high=upper, **kwargs)

        logits = []
        labels = []
        # Input kinds vary slowest; the input is sampled before its target.
        for input_kind, target_kind in product(("small", "big", "random"), ("zeros", "ones", "random_binary")):
            logits.append(to_logit(sample(input_kind)))
            labels.append(sample(target_kind))

        return torch.cat(logits), torch.cat(labels)

    @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
    @pytest.mark.parametrize("gamma", [0, 2])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("seed", [0, 1])
    def test_correct_ratio(self, alpha, gamma, device, dtype, seed):
        """The focal/cross-entropy ratio must match the manually computed factor."""
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        torch.random.manual_seed(seed)
        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)

        # The element-wise ratio can only be checked on unreduced losses.
        focal = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction="none")
        cross_entropy = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

        assert torch.all(
            focal <= cross_entropy
        ), "focal loss must be less or equal to cross entropy loss with same input"

        observed_ratio = (focal / cross_entropy).squeeze()

        probabilities = torch.sigmoid(inputs)
        p_t = probabilities * targets + (1 - probabilities) * (1 - targets)
        expected_ratio = (1.0 - p_t) ** gamma
        # The alpha weighting is only applied for non-negative alpha.
        if alpha >= 0:
            expected_ratio = expected_ratio * (alpha * targets + (1 - alpha) * (1 - targets))

        tol = 1e-3 if dtype is torch.half else 1e-5
        torch.testing.assert_close(expected_ratio, observed_ratio, atol=tol, rtol=tol)

    @pytest.mark.parametrize("reduction", ["mean", "sum"])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("seed", [2, 3])
    def test_equal_ce_loss(self, reduction, device, dtype, seed):
        """With alpha=-1 and gamma=0 the focal loss and its gradient equal cross entropy."""
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        torch.random.manual_seed(seed)
        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)

        # Independent leaf tensors so the two gradients can be compared afterwards.
        focal_inputs = inputs.clone().requires_grad_()
        ce_inputs = inputs.clone().requires_grad_()
        focal_targets = targets.clone()
        ce_targets = targets.clone()

        focal = ops.sigmoid_focal_loss(focal_inputs, focal_targets, gamma=0, alpha=-1, reduction=reduction)
        cross_entropy = F.binary_cross_entropy_with_logits(ce_inputs, ce_targets, reduction=reduction)

        torch.testing.assert_close(focal, cross_entropy)

        focal.backward()
        cross_entropy.backward()
        torch.testing.assert_close(focal_inputs.grad, ce_inputs.grad)

    @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
    @pytest.mark.parametrize("gamma", [0, 2])
    @pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("seed", [4, 5])
    def test_jit(self, alpha, gamma, reduction, device, dtype, seed):
        """The scripted loss must agree with the eager loss."""
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        scripted = torch.jit.script(ops.sigmoid_focal_loss)
        torch.random.manual_seed(seed)
        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)

        eager_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
        scripted_loss = scripted(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)

        tol = 1e-3 if dtype is torch.half else 1e-5
        torch.testing.assert_close(eager_loss, scripted_loss, rtol=tol, atol=tol)

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_reduction_mode(self, device, dtype, reduction="xyz"):
        """An unrecognised reduction mode must raise ``ValueError``."""
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        torch.random.manual_seed(0)
        inputs, targets = self._generate_diverse_input_target_pair(device=device, dtype=dtype)
        with pytest.raises(ValueError, match="Invalid"):
            ops.sigmoid_focal_loss(inputs, targets, 0.25, 2, reduction)

Abhijit Deo's avatar
Abhijit Deo committed
1709

1710
1711
class TestMasksToBoxes:
    """Tests for ``ops.masks_to_boxes``."""

    def test_masks_box(self):
        """Boxes computed from the multi-frame ``assets/masks.tiff`` match the reference.

        The check is repeated with the masks stored as float16, float32 and
        float64; the output is expected to be ``torch.float`` in every case.
        """

        def masks_box_check(masks, expected, atol=1e-4):
            out = ops.masks_to_boxes(masks)
            assert out.dtype == torch.float
            torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=True, atol=atol)

        def _load_frames():
            """Decode every frame once; return ``(frames, height, width)``."""
            assets_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
            mask_path = os.path.join(assets_directory, "masks.tiff")
            # The context manager closes the underlying file as soon as all
            # frames have been copied into numpy arrays.
            with Image.open(mask_path) as image:
                frames = []
                for index in range(image.n_frames):
                    image.seek(index)
                    frames.append(np.array(image))
                return frames, image.height, image.width

        def _create_masks(frames, masks):
            # Assignment casts each frame to the dtype of ``masks``.
            for index, frame in enumerate(frames):
                masks[index] = torch.tensor(frame)

            return masks

        expected = torch.tensor(
            [
                [127, 2, 165, 40],
                [2, 50, 44, 92],
                [56, 63, 98, 100],
                [139, 68, 175, 104],
                [160, 112, 198, 145],
                [49, 138, 99, 182],
                [108, 148, 152, 213],
            ],
            dtype=torch.float,
        )

        frames, height, width = _load_frames()
        # Check floating-point masks of several precisions against the same boxes.
        for dtype in [torch.float16, torch.float32, torch.float64]:
            masks = torch.zeros((len(frames), height, width), dtype=dtype)
            masks = _create_masks(frames, masks)
            masks_box_check(masks, expected)


1752
class TestStochasticDepth:
    """Tests for ``ops.StochasticDepth``."""

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("p", [0.2, 0.5, 0.8])
    @pytest.mark.parametrize("mode", ["batch", "row"])
    def test_stochastic_depth_random(self, seed, mode, p):
        """The observed drop frequency must be statistically compatible with ``p``."""
        torch.manual_seed(seed)
        stats = pytest.importorskip("scipy.stats")
        batch_size = 5
        x = torch.ones(size=(batch_size, 3, 4, 4))
        layer = ops.StochasticDepth(p=p, mode=mode)
        # Smoke-test the string representation.
        repr(layer)

        trials = 250
        dropped = 0
        total = 0
        for _ in range(trials):
            surviving_rows = layer(x).sum(dim=(1, 2, 3)).nonzero().size(0)
            if mode == "batch":
                # One sample per trial: counted as dropped when no row is non-zero.
                dropped += int(surviving_rows == 0)
                total += 1
            elif mode == "row":
                # One sample per row.
                dropped += batch_size - surviving_rows
                total += batch_size

        p_value = stats.binomtest(dropped, total, p=p).pvalue
        assert p_value > 0.01

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("p", (0, 1))
    @pytest.mark.parametrize("mode", ["batch", "row"])
    def test_stochastic_depth(self, seed, mode, p):
        """``p == 0`` returns the input unchanged and ``p == 1`` returns all zeros."""
        torch.manual_seed(seed)
        batch_size = 5
        x = torch.ones(size=(batch_size, 3, 4, 4))
        out = ops.StochasticDepth(p=p, mode=mode)(x)

        if p == 0:
            assert out.equal(x)
        elif p == 1:
            assert out.equal(torch.zeros_like(x))

    def make_obj(self, p, mode, wrap=False):
        """Create a ``StochasticDepth`` layer, wrapped in ``StochasticDepthWrapper`` if ``wrap`` is set."""
        layer = ops.StochasticDepth(p, mode)
        if wrap:
            return StochasticDepthWrapper(layer)
        return layer

    @pytest.mark.parametrize("p", (0, 1))
    @pytest.mark.parametrize("mode", ["batch", "row"])
    def test_is_leaf_node(self, p, mode):
        """The wrapped layer must contribute exactly one graph node besides its inputs."""
        op_obj = self.make_obj(p, mode, wrap=True)
        node_names = get_graph_node_names(op_obj)

        assert len(node_names) == 2
        first_names, second_names = node_names[0], node_names[1]
        assert len(first_names) == len(second_names)
        assert len(first_names) == 1 + op_obj.n_inputs

1810

1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
class TestUtils:
    """Tests for helpers in ``ops._utils``."""

    @pytest.mark.parametrize("norm_layer", [None, nn.BatchNorm2d, nn.LayerNorm])
    def test_split_normalization_params(self, norm_layer):
        """Splitting the parameters of ``mobilenet_v3_large`` yields groups of 92 and 82 entries."""
        model = models.mobilenet_v3_large(norm_layer=norm_layer)
        # Pass None when no explicit normalization layer was requested.
        norm_classes = [norm_layer] if norm_layer is not None else None
        params = ops._utils.split_normalization_params(model, norm_classes)

        assert len(params[0]) == 92
        assert len(params[1]) == 82


1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
class TestDropBlock:
    """Tests for ``ops.DropBlock2d`` and ``ops.DropBlock3d``."""

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("dim", [2, 3])
    @pytest.mark.parametrize("p", [0, 0.5])
    @pytest.mark.parametrize("block_size", [5, 11])
    @pytest.mark.parametrize("inplace", [True, False])
    def test_drop_block(self, seed, dim, p, block_size, inplace):
        """Check the ``p == 0`` identity case and the all-or-nothing full-size-block case."""
        torch.manual_seed(seed)
        batch_size, channels, side = 5, 3, 11
        if dim == 2:
            x = torch.ones(size=(batch_size, channels, side, side))
            layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace)
            feature_size = side * side
        elif dim == 3:
            x = torch.ones(size=(batch_size, channels, side, side, side))
            layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace)
            feature_size = side * side * side
        # Smoke-test the string representation.
        repr(layer)

        out = layer(x)
        if p == 0:
            assert out.equal(x)
        if block_size == side:
            # With a block as large as the feature map, each (sample, channel)
            # slice must be either fully kept or fully zeroed.
            for b in range(batch_size):
                for c in range(channels):
                    assert out[b, c].count_nonzero() in (0, feature_size)

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("dim", [2, 3])
    @pytest.mark.parametrize("p", [0.1, 0.2])
    @pytest.mark.parametrize("block_size", [3])
    @pytest.mark.parametrize("inplace", [False])
    def test_drop_block_random(self, seed, dim, p, block_size, inplace):
        """The fraction of zeroed elements over many trials must be within 15% of ``p``."""
        torch.manual_seed(seed)
        batch_size, channels, side = 5, 3, 11
        if dim == 2:
            x = torch.ones(size=(batch_size, channels, side, side))
            layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace)
        elif dim == 3:
            x = torch.ones(size=(batch_size, channels, side, side, side))
            layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace)

        trials = 250
        zeroed = 0
        total = 0
        elements_per_trial = torch.tensor(x.shape).prod()
        for _ in range(trials):
            with torch.no_grad():
                out = layer(x)
            zeroed += elements_per_trial - out.nonzero().size(0)
            total += elements_per_trial

        assert abs(p - zeroed / total) / p < 0.15

    def make_obj(self, dim, p, block_size, inplace, wrap=False):
        """Create a DropBlock layer for ``dim``, wrapped in ``DropBlockWrapper`` if ``wrap`` is set."""
        if dim == 2:
            layer = ops.DropBlock2d(p, block_size, inplace)
        elif dim == 3:
            layer = ops.DropBlock3d(p, block_size, inplace)
        if wrap:
            return DropBlockWrapper(layer)
        return layer

    @pytest.mark.parametrize("dim", (2, 3))
    @pytest.mark.parametrize("p", [0, 1])
    @pytest.mark.parametrize("block_size", [5, 7])
    @pytest.mark.parametrize("inplace", [True, False])
    def test_is_leaf_node(self, dim, p, block_size, inplace):
        """The wrapped layer must contribute exactly one graph node besides its inputs."""
        op_obj = self.make_obj(dim, p, block_size, inplace, wrap=True)
        node_names = get_graph_node_names(op_obj)

        assert len(node_names) == 2
        first_names, second_names = node_names[0], node_names[1]
        assert len(first_names) == len(second_names)
        assert len(first_names) == 1 + op_obj.n_inputs


1903
# Allow running this test module directly as a script.
if __name__ == "__main__":
    pytest.main([__file__])