test_ops.py 79.6 KB
Newer Older
1
import math
2
import os
3
from abc import ABC, abstractmethod
4
from functools import lru_cache
5
from itertools import product
6
from typing import Callable, List, Tuple
7

8
import numpy as np
9
import pytest
10
import torch
11
import torch.fx
12
import torch.nn.functional as F
13
import torch.testing._internal.optests as optests
14
from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps
15
from PIL import Image
16
from torch import nn, Tensor
17
from torch._dynamo.utils import is_compile_supported
18
from torch.autograd import gradcheck
19
from torch.nn.modules.utils import _pair
20
from torchvision import models, ops
21
22
23
from torchvision.models.feature_extraction import get_graph_node_names


24
25
26
27
28
29
30
31
# Names of the opcheck test utilities used for the custom ops.
# NOTE(review): not referenced in the visible part of this file — presumably
# consumed by an opcheck test generator further down; confirm.
OPTESTS = [
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
]


32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Context manager for setting deterministic flag and automatically
# resetting it to its original value
class DeterministicGuard:
    """Temporarily set torch's deterministic-algorithms flags.

    On entry the current ``deterministic`` / ``warn_only`` settings are saved
    and replaced by the requested ones; on exit the saved settings are
    restored, whether or not the body raised.

    Args:
        deterministic: value passed to ``torch.use_deterministic_algorithms``.
        warn_only: keyword-only ``warn_only`` flag passed alongside it.
    """

    def __init__(self, deterministic, *, warn_only=False):
        self.deterministic = deterministic
        self.warn_only = warn_only

    def __enter__(self):
        # Remember the global state so __exit__ can put it back.
        self.deterministic_restore = torch.are_deterministic_algorithms_enabled()
        self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled()
        torch.use_deterministic_algorithms(self.deterministic, warn_only=self.warn_only)
        # Return the guard so ``with DeterministicGuard(...) as g`` binds it
        # (previously this bound None).
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        # Returns None, so exceptions raised in the body still propagate.
        torch.use_deterministic_algorithms(self.deterministic_restore, warn_only=self.warn_only_restore)


48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
class RoIOpTesterModuleWrapper(nn.Module):
    """Two-input module that just calls the wrapped RoI layer.

    ``n_inputs`` records how many arguments ``forward`` takes.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 2
        self.layer = obj

    def forward(self, a, b):
        # The layer's result is intentionally discarded; forward returns None.
        self.layer(a, b)

class MultiScaleRoIAlignModuleWrapper(nn.Module):
    """Three-input module that just calls the wrapped layer.

    ``n_inputs`` records how many arguments ``forward`` takes.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 3
        self.layer = obj

    def forward(self, a, b, c):
        # The layer's result is intentionally discarded; forward returns None.
        self.layer(a, b, c)


class DeformConvModuleWrapper(nn.Module):
    """Three-input module that just calls the wrapped layer.

    ``n_inputs`` records how many arguments ``forward`` takes.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 3
        self.layer = obj

    def forward(self, a, b, c):
        # The layer's result is intentionally discarded; forward returns None.
        self.layer(a, b, c)


class StochasticDepthWrapper(nn.Module):
    """Single-input module that just calls the wrapped layer.

    ``n_inputs`` records how many arguments ``forward`` takes.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 1
        self.layer = obj

    def forward(self, a):
        # The layer's result is intentionally discarded; forward returns None.
        self.layer(a)
86
87


88
89
90
91
92
93
94
95
96
97
class DropBlockWrapper(nn.Module):
    """Single-input module that just calls the wrapped layer.

    ``n_inputs`` records how many arguments ``forward`` takes.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 1
        self.layer = obj

    def forward(self, a):
        # The layer's result is intentionally discarded; forward returns None.
        self.layer(a)


98
99
100
101
102
103
104
105
106
class PoolWrapper(nn.Module):
    """Module with a typed ``forward(imgs, boxes)`` that delegates to ``pool``."""

    def __init__(self, pool: nn.Module):
        super().__init__()
        self.pool = pool

    def forward(self, imgs: Tensor, boxes: List[Tensor]) -> Tensor:
        # Boxes are passed as a list of tensors and forwarded unchanged.
        pooled = self.pool(imgs, boxes)
        return pooled


107
108
class RoIOpTester(ABC):
    """Shared test suite for the RoI pooling / align operators.

    Subclasses supply the operator under test through ``fn``, ``make_obj`` and
    ``get_script_fn``, and a pure-Python reference implementation through
    ``expected_fn``. RoIs are rows of ``(batch_index, x1, y1, x2, y2)``.
    """

    # dtype used by the gradient checks on non-MPS devices
    dtype = torch.float64
    # dtype and absolute tolerance used by the gradient checks on MPS
    mps_dtype = torch.float32
    mps_backward_atol = 2e-2

    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize(
        "x_dtype",
        (
            torch.float16,
            torch.float32,
            torch.float64,
        ),
        ids=str,
    )
    def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, deterministic=False, **kwargs):
        """Compare ``self.fn`` against ``self.expected_fn`` on a fixed set of RoIs.

        ``rois_dtype`` defaults to ``x_dtype``; extra ``kwargs`` are forwarded to
        both ``fn`` and ``expected_fn``.
        """
        if device == "mps" and x_dtype is torch.float64:
            pytest.skip("MPS does not support float64")

        rois_dtype = x_dtype if rois_dtype is None else rois_dtype

        # Reduced-precision inputs get a looser tolerance.
        tol = 1e-5
        if x_dtype is torch.half:
            if device == "mps":
                tol = 5e-3
            else:
                tol = 4e-3
        elif x_dtype == torch.bfloat16:
            tol = 5e-3

        pool_size = 5
        # n_channels % (pool_size ** 2) == 0 required for PS operations.
        n_channels = 2 * (pool_size**2)
        x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device)
        if not contiguous:
            # Swapping the last two dims yields a non-contiguous view.
            x = x.permute(0, 1, 3, 2)
        rois = torch.tensor(
            [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]],  # format is (xyxy)
            dtype=rois_dtype,
            device=device,
        )

        pool_h, pool_w = pool_size, pool_size
        with DeterministicGuard(deterministic):
            y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs)
        # the following should be true whether we're running an autocast test or not.
        assert y.dtype == x.dtype
        gt_y = self.expected_fn(
            x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=x_dtype, **kwargs
        )

        torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_is_leaf_node(self, device):
        """The wrapped op must show up as one graph node next to the wrapper's inputs."""
        op_obj = self.make_obj(wrap=True).to(device=device)
        graph_node_names = get_graph_node_names(op_obj)

        assert len(graph_node_names) == 2
        assert len(graph_node_names[0]) == len(graph_node_names[1])
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.float):
        """A symbolically traced module must reproduce the eager module's output."""
        op_obj = self.make_obj().to(device=device)
        graph_module = torch.fx.symbolic_trace(op_obj)
        pool_size = 5
        n_channels = 2 * (pool_size**2)
        x = torch.rand(2, n_channels, 5, 5, dtype=x_dtype, device=device)
        rois = torch.tensor(
            [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]],  # format is (xyxy)
            dtype=rois_dtype,
            device=device,
        )
        output_gt = op_obj(x, rois)
        assert output_gt.dtype == x.dtype
        output_fx = graph_module(x, rois)
        assert output_fx.dtype == x.dtype
        tol = 1e-5
        torch.testing.assert_close(output_gt, output_fx, rtol=tol, atol=tol)

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
    @pytest.mark.parametrize("contiguous", (True, False))
    def test_backward(self, seed, device, contiguous, deterministic=False):
        """Run ``gradcheck`` on both the eager op and its scripted counterpart."""
        atol = self.mps_backward_atol if device == "mps" else 1e-05
        dtype = self.mps_dtype if device == "mps" else self.dtype

        torch.random.manual_seed(seed)
        pool_size = 2
        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=dtype, device=device, requires_grad=True)
        if not contiguous:
            x = x.permute(0, 1, 3, 2)
        rois = torch.tensor(
            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=dtype, device=device  # format is (xyxy)
        )

        def func(z):
            return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)

        script_func = self.get_script_fn(rois, pool_size)

        # Only the eager check runs under the deterministic guard.
        with DeterministicGuard(deterministic):
            gradcheck(func, (x,), atol=atol)

        gradcheck(script_func, (x,), atol=atol)

    @needs_mps
    def test_mps_error_inputs(self):
        """The backward pass with float16 inputs on MPS must raise a RuntimeError."""
        pool_size = 2
        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=torch.float16, device="mps", requires_grad=True)
        rois = torch.tensor(
            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=torch.float16, device="mps"  # format is (xyxy)
        )

        def func(z):
            return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)

        with pytest.raises(
            RuntimeError, match="MPS does not support (?:ps_)?roi_(?:align|pool)? backward with float16 inputs."
        ):
            gradcheck(func, (x,))

    @needs_cuda
    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
    def test_autocast(self, x_dtype, rois_dtype):
        """Re-run the forward test inside a CUDA autocast region."""
        with torch.cuda.amp.autocast():
            self.test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)

    def _helper_boxes_shape(self, func):
        """Check that ``func`` rejects boxes of the wrong shape with an AssertionError.

        Args:
            func: functional RoI op, called as ``func(input, boxes, output_size=...)``.
        """
        # Boxes given as a single Tensor are expected as Tensor[N, 5];
        # a 4-column tensor must be rejected.
        with pytest.raises(AssertionError):
            a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
            boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype)
            func(a, boxes, output_size=(2, 2))

        # Boxes given as a list are expected as List[Tensor[N, 4]];
        # a 3-column tensor must be rejected. This exercises ``func`` itself
        # (it used to always call ops.roi_pool, whatever op was under test).
        with pytest.raises(AssertionError):
            a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
            boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype)
            func(a, [boxes], output_size=(2, 2))

    def _helper_jit_boxes_list(self, model):
        """Script ``model`` and run it with the boxes given as a list of tensors."""
        x = torch.rand(2, 1, 10, 10)
        # Transposing the 4x5 table gives 5 rows of 4 values per image.
        roi = torch.tensor([[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], dtype=torch.float).t()
        rois = [roi, roi]
        scripted = torch.jit.script(model)
        y = scripted(x, rois)
        assert y.shape == (10, 1, 3, 3)

    @abstractmethod
    def fn(*args, **kwargs):
        """Run the operator under test; see ``test_forward`` for the call convention."""
        pass

    @abstractmethod
    def make_obj(*args, **kwargs):
        """Build the module form of the operator, optionally wrapped (``wrap=True``)."""
        pass

    @abstractmethod
    def get_script_fn(*args, **kwargs):
        """Return a callable ``x -> output`` backed by the scripted functional op."""
        pass

    @abstractmethod
    def expected_fn(*args, **kwargs):
        """Reference implementation whose output ``fn`` is compared against."""
        pass
274

275

276
class TestRoiPool(RoIOpTester):
    """Runs the shared RoI test suite against ``ops.RoIPool`` / ``ops.roi_pool``."""

    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        # sampling_ratio and extra kwargs are accepted for a uniform signature but unused.
        return ops.RoIPool((pool_h, pool_w), spatial_scale)(x, rois)

    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False):
        obj = ops.RoIPool((pool_h, pool_w), spatial_scale)
        return RoIOpTesterModuleWrapper(obj) if wrap else obj

    def get_script_fn(self, rois, pool_size):
        scripted = torch.jit.script(ops.roi_pool)
        return lambda x: scripted(x, rois, pool_size)

    def expected_fn(
        self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64
    ):
        """Pure-Python reference: per-channel max over each bin of each RoI.

        Returns a tensor of shape ``(num_rois, channels, pool_h, pool_w)`` with the
        requested ``dtype`` on ``device`` (CPU when ``device`` is None).
        """
        if device is None:
            device = torch.device("cpu")

        n_channels = x.size(1)
        y = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device)

        def get_slice(k, block):
            # Bin k covers [floor(k * block), ceil((k + 1) * block)).
            return slice(int(np.floor(k * block)), int(np.ceil((k + 1) * block)))

        for roi_idx, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (int(round(x.item() * spatial_scale)) for x in roi[1:])
            # The end coordinates are inclusive.
            roi_x = x[batch_idx, :, i_begin : i_end + 1, j_begin : j_end + 1]

            roi_h, roi_w = roi_x.shape[-2:]
            bin_h = roi_h / pool_h
            bin_w = roi_w / pool_w

            for i in range(0, pool_h):
                for j in range(0, pool_w):
                    bin_x = roi_x[:, get_slice(i, bin_h), get_slice(j, bin_w)]
                    # Empty bins keep the zero initialisation.
                    if bin_x.numel() > 0:
                        y[roi_idx, :, i, j] = bin_x.reshape(n_channels, -1).max(dim=1)[0]
        return y

    def test_boxes_shape(self):
        self._helper_boxes_shape(ops.roi_pool)

    def test_jit_boxes_list(self):
        model = PoolWrapper(ops.RoIPool(output_size=[3, 3], spatial_scale=1.0))
        self._helper_jit_boxes_list(model)

323

324
class TestPSRoIPool(RoIOpTester):
    """Runs the shared RoI test suite against ``ops.PSRoIPool`` / ``ops.ps_roi_pool``."""

    mps_backward_atol = 5e-2

    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        # Forward the caller's spatial_scale (it used to be hard-coded to 1 and
        # silently ignored). sampling_ratio and extra kwargs are unused here.
        return ops.PSRoIPool((pool_h, pool_w), spatial_scale)(x, rois)

    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False):
        obj = ops.PSRoIPool((pool_h, pool_w), spatial_scale)
        return RoIOpTesterModuleWrapper(obj) if wrap else obj

    def get_script_fn(self, rois, pool_size):
        scripted = torch.jit.script(ops.ps_roi_pool)
        return lambda x: scripted(x, rois, pool_size)

    def expected_fn(
        self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64
    ):
        """Pure-Python reference: position-sensitive average over each bin of each RoI.

        Output channel ``c_out`` at bin ``(i, j)`` reads input channel
        ``c_out * pool_h * pool_w + pool_w * i + j``. Returns a tensor of shape
        ``(num_rois, channels // (pool_h * pool_w), pool_h, pool_w)``.
        """
        if device is None:
            device = torch.device("cpu")
        n_input_channels = x.size(1)
        assert n_input_channels % (pool_h * pool_w) == 0, "input channels must be divisible by ph * pw"
        n_output_channels = int(n_input_channels / (pool_h * pool_w))
        y = torch.zeros(rois.size(0), n_output_channels, pool_h, pool_w, dtype=dtype, device=device)

        def get_slice(k, block):
            # Bin k covers [floor(k * block), ceil((k + 1) * block)).
            return slice(int(np.floor(k * block)), int(np.ceil((k + 1) * block)))

        for roi_idx, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (int(round(x.item() * spatial_scale)) for x in roi[1:])
            roi_x = x[batch_idx, :, i_begin : i_end + 1, j_begin : j_end + 1]

            # Bin sizes use the exclusive extent, clamped to at least 1.
            roi_height = max(i_end - i_begin, 1)
            roi_width = max(j_end - j_begin, 1)
            bin_h, bin_w = roi_height / float(pool_h), roi_width / float(pool_w)

            for i in range(0, pool_h):
                for j in range(0, pool_w):
                    bin_x = roi_x[:, get_slice(i, bin_h), get_slice(j, bin_w)]
                    # Empty bins keep the zero initialisation.
                    if bin_x.numel() > 0:
                        area = bin_x.size(-2) * bin_x.size(-1)
                        for c_out in range(0, n_output_channels):
                            c_in = c_out * (pool_h * pool_w) + pool_w * i + j
                            t = torch.sum(bin_x[c_in, :, :])
                            y[roi_idx, c_out, i, j] = t / area
        return y

    def test_boxes_shape(self):
        self._helper_boxes_shape(ops.ps_roi_pool)

374

375
376
def bilinear_interpolate(data, y, x, snap_border=False):
    """Bilinearly interpolate the 2D array ``data`` at the point ``(y, x)``.

    Neighbours that fall outside ``data`` contribute nothing, so a point
    entirely outside the array yields 0.

    Args:
        data: 2D array-like exposing ``.shape`` and ``data[row, col]`` indexing.
        y: row coordinate of the sample point.
        x: column coordinate of the sample point.
        snap_border: when True, coordinates in ``(-1, 0]`` are moved to 0 and
            coordinates in ``[size - 1, size)`` are moved to ``size - 1``
            before interpolating.

    Returns:
        The weighted sum of the (up to four) in-bounds neighbours.
    """
    height, width = data.shape

    if snap_border:
        if -1 < y <= 0:
            y = 0
        elif height - 1 <= y < height:
            y = height - 1

        if -1 < x <= 0:
            x = 0
        elif width - 1 <= x < width:
            x = width - 1

    y_low = int(math.floor(y))
    x_low = int(math.floor(x))
    y_high = y_low + 1
    x_high = x_low + 1

    # Weights of the "high" neighbours are the fractional parts; the "low"
    # neighbours get the complement.
    wy_h = y - y_low
    wx_h = x - x_low
    wy_l = 1 - wy_h
    wx_l = 1 - wx_h

    val = 0
    for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):
        for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):
            if 0 <= yp < height and 0 <= xp < width:
                val += wx * wy * data[yp, xp]
    return val
405
406


407
class TestRoIAlign(RoIOpTester):
    """Runs the shared RoI test suite against ``ops.RoIAlign`` / ``ops.roi_align``.

    Adds ``aligned`` / ``deterministic`` parametrisations and tests for the
    quantized variant of the op.
    """

    mps_backward_atol = 6e-2

    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs):
        return ops.RoIAlign(
            (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
        )(x, rois)

    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, aligned=False, wrap=False):
        obj = ops.RoIAlign(
            (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
        )
        return RoIOpTesterModuleWrapper(obj) if wrap else obj

    def get_script_fn(self, rois, pool_size):
        scripted = torch.jit.script(ops.roi_align)
        return lambda x: scripted(x, rois, pool_size)

    def expected_fn(
        self,
        in_data,
        rois,
        pool_h,
        pool_w,
        spatial_scale=1,
        sampling_ratio=-1,
        aligned=False,
        device=None,
        dtype=torch.float64,
    ):
        """Pure-Python reference: average of bilinear samples inside each bin.

        Each bin is sampled on a ``grid_h x grid_w`` grid (``sampling_ratio`` per
        side when positive, otherwise ``ceil(bin size)``). With ``aligned`` the
        scaled RoI coordinates are shifted by -0.5. Returns a tensor of shape
        ``(num_rois, channels, pool_h, pool_w)``.
        """
        if device is None:
            device = torch.device("cpu")
        n_channels = in_data.size(1)
        out_data = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device)

        offset = 0.5 if aligned else 0.0

        for r, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (x.item() * spatial_scale - offset for x in roi[1:])

            roi_h = i_end - i_begin
            roi_w = j_end - j_begin
            bin_h = roi_h / pool_h
            bin_w = roi_w / pool_w

            for i in range(0, pool_h):
                start_h = i_begin + i * bin_h
                grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h))
                for j in range(0, pool_w):
                    start_w = j_begin + j * bin_w
                    grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))

                    for channel in range(0, n_channels):
                        val = 0
                        for iy in range(0, grid_h):
                            # Sample at the centre of each grid cell.
                            y = start_h + (iy + 0.5) * bin_h / grid_h
                            for ix in range(0, grid_w):
                                x = start_w + (ix + 0.5) * bin_w / grid_w
                                val += bilinear_interpolate(in_data[batch_idx, channel, :, :], y, x, snap_border=True)
                        val /= grid_h * grid_w

                        out_data[r, channel, i, j] = val
        return out_data

    def test_boxes_shape(self):
        self._helper_boxes_shape(ops.roi_align)

    @pytest.mark.parametrize("aligned", (True, False))
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
    @pytest.mark.parametrize("x_dtype", (torch.float16, torch.float32, torch.float64))  # , ids=str)
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("deterministic", (True, False))
    @pytest.mark.opcheck_only_one()
    def test_forward(self, device, contiguous, deterministic, aligned, x_dtype, rois_dtype=None):
        """Forward test of the base class, additionally parametrised over ``aligned`` / ``deterministic``."""
        if deterministic and device == "cpu":
            pytest.skip("cpu is always deterministic, don't retest")
        super().test_forward(
            device=device,
            contiguous=contiguous,
            deterministic=deterministic,
            x_dtype=x_dtype,
            rois_dtype=rois_dtype,
            aligned=aligned,
        )

    @needs_cuda
    @pytest.mark.parametrize("aligned", (True, False))
    @pytest.mark.parametrize("deterministic", (True, False))
    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
    @pytest.mark.opcheck_only_one()
    def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
        """Re-run the forward test inside a CUDA autocast region."""
        with torch.cuda.amp.autocast():
            self.test_forward(
                torch.device("cuda"),
                contiguous=False,
                deterministic=deterministic,
                aligned=aligned,
                x_dtype=x_dtype,
                rois_dtype=rois_dtype,
            )

    @pytest.mark.parametrize("aligned", (True, False))
    @pytest.mark.parametrize("deterministic", (True, False))
    @pytest.mark.parametrize("x_dtype", (torch.float, torch.bfloat16))
    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.bfloat16))
    def test_autocast_cpu(self, aligned, deterministic, x_dtype, rois_dtype):
        """Re-run the forward test inside a CPU autocast region."""
        with torch.cpu.amp.autocast():
            self.test_forward(
                torch.device("cpu"),
                contiguous=False,
                deterministic=deterministic,
                aligned=aligned,
                x_dtype=x_dtype,
                rois_dtype=rois_dtype,
            )

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("deterministic", (True, False))
    @pytest.mark.opcheck_only_one()
    def test_backward(self, seed, device, contiguous, deterministic):
        """Backward test of the base class, skipping unsupported deterministic setups."""
        if deterministic and device == "cpu":
            pytest.skip("cpu is always deterministic, don't retest")
        if deterministic and device == "mps":
            pytest.skip("no deterministic implementation for mps")
        if deterministic and not is_compile_supported(device):
            pytest.skip("deterministic implementation only if torch.compile supported")
        super().test_backward(seed, device, contiguous, deterministic)

    def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
        """Return ``num_rois`` random RoIs as a ``(num_rois, 5)`` tensor of ``dtype``."""
        rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype)
        rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,))  # set batch index
        rois[:, 3:] += rois[:, 1:3]  # make sure boxes aren't degenerate
        return rois

    @pytest.mark.parametrize("aligned", (True, False))
    @pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 10), (0.1, 50)))
    @pytest.mark.parametrize("qdtype", (torch.qint8, torch.quint8, torch.qint32))
    @pytest.mark.opcheck_only_one()
    def test_qroialign(self, aligned, scale, zero_point, qdtype):
        """Make sure quantized version of RoIAlign is close to float version"""
        pool_size = 5
        img_size = 10
        n_channels = 2
        num_imgs = 1
        dtype = torch.float

        x = torch.randint(50, 100, size=(num_imgs, n_channels, img_size, img_size)).to(dtype)
        qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=qdtype)

        rois = self._make_rois(img_size, num_imgs, dtype)
        qrois = torch.quantize_per_tensor(rois, scale=scale, zero_point=zero_point, dtype=qdtype)

        x, rois = qx.dequantize(), qrois.dequantize()  # we want to pass the same inputs

        y = ops.roi_align(
            x,
            rois,
            output_size=pool_size,
            spatial_scale=1,
            sampling_ratio=-1,
            aligned=aligned,
        )
        qy = ops.roi_align(
            qx,
            qrois,
            output_size=pool_size,
            spatial_scale=1,
            sampling_ratio=-1,
            aligned=aligned,
        )

        # The output qy is itself a quantized tensor and there might have been a loss of info when it was
        # quantized. For a fair comparison we need to quantize y as well
        quantized_float_y = torch.quantize_per_tensor(y, scale=scale, zero_point=zero_point, dtype=qdtype)

        try:
            # Ideally, we would assert this, which passes with (scale, zero) == (1, 0)
            assert (qy == quantized_float_y).all()
        except AssertionError:
            # But because the computations aren't exactly the same between the 2 RoIAlign procedures, some
            # rounding error may lead to a difference of 2 in the output.
            # For example with (scale, zero) = (2, 10), 45.00000... will be quantized to 44
            # but 45.00000001 will be rounded to 46. We make sure below that:
            # - such discrepancies between qy and quantized_float_y are very rare (less than 5%)
            # - any difference between qy and quantized_float_y is == scale
            diff_idx = torch.where(qy != quantized_float_y)
            num_diff = diff_idx[0].numel()
            assert num_diff / qy.numel() < 0.05

            abs_diff = torch.abs(qy[diff_idx].dequantize() - quantized_float_y[diff_idx].dequantize())
            t_scale = torch.full_like(abs_diff, fill_value=scale)
            torch.testing.assert_close(abs_diff, t_scale, rtol=1e-5, atol=1e-5)

    def test_qroi_align_multiple_images(self):
        """Quantized roi_align must reject a batch holding more than one image."""
        dtype = torch.float
        x = torch.randint(50, 100, size=(2, 3, 10, 10)).to(dtype)
        qx = torch.quantize_per_tensor(x, scale=1, zero_point=0, dtype=torch.qint8)
        rois = self._make_rois(img_size=10, num_imgs=2, dtype=dtype, num_rois=10)
        qrois = torch.quantize_per_tensor(rois, scale=1, zero_point=0, dtype=torch.qint8)
        with pytest.raises(RuntimeError, match="Only one image per batch is allowed"):
            ops.roi_align(qx, qrois, output_size=5)

    def test_jit_boxes_list(self):
        model = PoolWrapper(ops.RoIAlign(output_size=[3, 3], spatial_scale=1.0, sampling_ratio=-1))
        self._helper_jit_boxes_list(model)

617

618
class TestPSRoIAlign(RoIOpTester):
    """Runs the shared RoI test suite against ``ops.PSRoIAlign`` / ``ops.ps_roi_align``."""

    mps_backward_atol = 5e-2

    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois)

    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, wrap=False):
        obj = ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)
        return RoIOpTesterModuleWrapper(obj) if wrap else obj

    def get_script_fn(self, rois, pool_size):
        scripted = torch.jit.script(ops.ps_roi_align)
        return lambda x: scripted(x, rois, pool_size)

    def expected_fn(
        self, in_data, rois, pool_h, pool_w, device=None, spatial_scale=1, sampling_ratio=-1, dtype=torch.float64
    ):
        """Pure-Python reference: position-sensitive average of bilinear samples.

        Output channel ``c_out`` at bin ``(i, j)`` reads input channel
        ``c_out * pool_h * pool_w + pool_w * i + j``. The scaled RoI coordinates
        are always shifted by -0.5. ``device`` keeps its position in the
        signature but now defaults to None (meaning CPU), matching the other
        reference implementations; the body already handled ``None``.
        """
        if device is None:
            device = torch.device("cpu")
        n_input_channels = in_data.size(1)
        assert n_input_channels % (pool_h * pool_w) == 0, "input channels must be divisible by ph * pw"
        n_output_channels = int(n_input_channels / (pool_h * pool_w))
        out_data = torch.zeros(rois.size(0), n_output_channels, pool_h, pool_w, dtype=dtype, device=device)

        for r, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (x.item() * spatial_scale - 0.5 for x in roi[1:])

            roi_h = i_end - i_begin
            roi_w = j_end - j_begin
            bin_h = roi_h / pool_h
            bin_w = roi_w / pool_w

            for i in range(0, pool_h):
                start_h = i_begin + i * bin_h
                grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h))
                for j in range(0, pool_w):
                    start_w = j_begin + j * bin_w
                    grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))
                    for c_out in range(0, n_output_channels):
                        c_in = c_out * (pool_h * pool_w) + pool_w * i + j

                        val = 0
                        for iy in range(0, grid_h):
                            # Sample at the centre of each grid cell.
                            y = start_h + (iy + 0.5) * bin_h / grid_h
                            for ix in range(0, grid_w):
                                x = start_w + (ix + 0.5) * bin_w / grid_w
                                val += bilinear_interpolate(in_data[batch_idx, c_in, :, :], y, x, snap_border=True)
                        val /= grid_h * grid_w

                        out_data[r, c_out, i, j] = val
        return out_data

    def test_boxes_shape(self):
        self._helper_boxes_shape(ops.ps_roi_align)

674

675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
@pytest.mark.parametrize(
    "op",
    (
        torch.ops.torchvision.roi_pool,
        torch.ops.torchvision.ps_roi_pool,
        torch.ops.torchvision.roi_align,
        torch.ops.torchvision.ps_roi_align,
    ),
)
@pytest.mark.parametrize("dtype", (torch.float16, torch.float32, torch.float64))
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("requires_grad", (True, False))
def test_roi_opcheck(op, dtype, device, requires_grad):
    # opcheck() is invoked by hand for the RoI ops rather than through
    # opcheck.generate_opcheck_tests() (as done e.g. for nms): pytest and
    # generate_opcheck_tests() don't interact very well when it comes to
    # skipping tests, and these ops need to skip the MPS tests because dynamic
    # shapes are not supported on MPS yet.
    box_rows = [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]]
    rois = torch.tensor(box_rows, dtype=dtype, device=device, requires_grad=requires_grad)

    pool_size = 5
    num_channels = 2 * (pool_size**2)
    x = torch.rand(2, num_channels, 10, 10, dtype=dtype, device=device)

    tv_ops = torch.ops.torchvision
    kwargs = {
        "rois": rois,
        "spatial_scale": 1,
        "pooled_height": pool_size,
        "pooled_width": pool_size,
    }
    # Only the align variants take a sampling ratio; only roi_align takes `aligned`.
    if op in (tv_ops.roi_align, tv_ops.ps_roi_align):
        kwargs["sampling_ratio"] = -1
    if op is tv_ops.roi_align:
        kwargs["aligned"] = True

    optests.opcheck(op, args=(x,), kwargs=kwargs)


712
class TestMultiScaleRoIAlign:
    """Tests for ``ops.poolers.MultiScaleRoIAlign``."""

    def make_obj(self, fmap_names=None, output_size=(7, 7), sampling_ratio=2, wrap=False):
        """Build a MultiScaleRoIAlign, optionally wrapped in the three-input test module.

        ``fmap_names`` defaults to ``["0"]``.
        """
        if fmap_names is None:
            fmap_names = ["0"]
        obj = ops.poolers.MultiScaleRoIAlign(fmap_names, output_size, sampling_ratio)
        return MultiScaleRoIAlignModuleWrapper(obj) if wrap else obj

    def test_msroialign_repr(self):
        fmap_names = ["0"]
        output_size = (7, 7)
        sampling_ratio = 2
        # Pass mock feature map names
        t = self.make_obj(fmap_names, output_size, sampling_ratio, wrap=False)

        # Check integrity of object __repr__ attribute
        expected_string = (
            f"MultiScaleRoIAlign(featmap_names={fmap_names}, output_size={output_size}, "
            f"sampling_ratio={sampling_ratio})"
        )
        assert repr(t) == expected_string

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_is_leaf_node(self, device):
        """The wrapped op must show up as one graph node next to the wrapper's inputs."""
        op_obj = self.make_obj(wrap=True).to(device=device)
        graph_node_names = get_graph_node_names(op_obj)

        assert len(graph_node_names) == 2
        assert len(graph_node_names[0]) == len(graph_node_names[1])
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs

742

743
744
class TestNMS:
    """Tests for ``ops.nms`` (CPU, quantized, accelerator, autocast) and the ``batched_nms`` helpers."""

    def _reference_nms(self, boxes, scores, iou_threshold):
        """
        Pure-PyTorch reference NMS used to validate ``ops.nms``.

        Args:
            boxes: boxes in corner-form
            scores: probabilities
            iou_threshold: intersection over union threshold
        Returns:
             picked: a tensor with the indexes of the kept boxes, highest score first
        """
        picked = []
        _, indexes = scores.sort(descending=True)
        while len(indexes) > 0:
            current = indexes[0]
            picked.append(current.item())
            if len(indexes) == 1:
                break
            current_box = boxes[current, :]
            indexes = indexes[1:]
            rest_boxes = boxes[indexes, :]
            iou = ops.box_iou(rest_boxes, current_box.unsqueeze(0)).squeeze(1)
            # Only boxes that do not overlap the current one too much stay in the running.
            indexes = indexes[iou <= iou_threshold]

        return torch.as_tensor(picked)

    def _create_tensors_with_iou(self, N, iou_thresh):
        """Return ``N`` random boxes and scores where the last box barely exceeds ``iou_thresh`` with the first."""
        # force last box to have a pre-defined iou with the first box
        # let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1],
        # then, in order to satisfy ops.iou(b0, b1) == iou_thresh,
        # we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh
        # Adjust the threshold upward a bit with the intent of creating
        # at least one box that exceeds (barely) the threshold and so
        # should be suppressed.
        boxes = torch.rand(N, 4) * 100
        boxes[:, 2:] += boxes[:, :2]
        boxes[-1, :] = boxes[0, :]
        x0, y0, x1, y1 = boxes[-1].tolist()
        iou_thresh += 1e-5
        boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh
        scores = torch.rand(N)
        return boxes, scores

    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.opcheck_only_one()
    def test_nms_ref(self, iou, seed):
        torch.random.manual_seed(seed)
        err_msg = "NMS incompatible between CPU and reference implementation for IoU={}"
        boxes, scores = self._create_tensors_with_iou(1000, iou)
        keep_ref = self._reference_nms(boxes, scores, iou)
        keep = ops.nms(boxes, scores, iou)
        torch.testing.assert_close(keep, keep_ref, msg=err_msg.format(iou))

    def test_nms_input_errors(self):
        # Each call below passes boxes/scores with inconsistent or invalid shapes.
        with pytest.raises(RuntimeError):
            ops.nms(torch.rand(4), torch.rand(3), 0.5)
        with pytest.raises(RuntimeError):
            ops.nms(torch.rand(3, 5), torch.rand(3), 0.5)
        with pytest.raises(RuntimeError):
            ops.nms(torch.rand(3, 4), torch.rand(3, 2), 0.5)
        with pytest.raises(RuntimeError):
            ops.nms(torch.rand(3, 4), torch.rand(4), 0.5)

    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 50), (3, 10)))
    @pytest.mark.opcheck_only_one()
    def test_qnms(self, iou, scale, zero_point):
        # Note: we compare qnms vs nms instead of qnms vs reference implementation.
        # This is because with the int conversion, the trick used in _create_tensors_with_iou
        # doesn't really work (in fact, nms vs reference implem will also fail with ints)
        err_msg = "NMS and QNMS give different results for IoU={}"
        boxes, scores = self._create_tensors_with_iou(1000, iou)
        scores *= 100  # otherwise most scores would be 0 or 1 after int conversion

        qboxes = torch.quantize_per_tensor(boxes, scale=scale, zero_point=zero_point, dtype=torch.quint8)
        qscores = torch.quantize_per_tensor(scores, scale=scale, zero_point=zero_point, dtype=torch.quint8)

        # Run the float path on the dequantized values so both paths see the same numbers.
        boxes = qboxes.dequantize()
        scores = qscores.dequantize()

        keep = ops.nms(boxes, scores, iou)
        qkeep = ops.nms(qboxes, qscores, iou)

        torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou))

    @pytest.mark.parametrize(
        "device",
        (
            pytest.param("cuda", marks=pytest.mark.needs_cuda),
            pytest.param("mps", marks=pytest.mark.needs_mps),
        ),
    )
    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.opcheck_only_one()
    def test_nms_gpu(self, iou, device, dtype=torch.float64):
        # ``dtype`` only selects the comparison tolerance below; the inputs themselves are not cast.
        dtype = torch.float32 if device == "mps" else dtype
        tol = 1e-3 if dtype is torch.half else 1e-5
        # Name the accelerator actually under test (this test also runs on MPS).
        err_msg = "NMS incompatible between CPU and {} for IoU={}".format(str(device).upper(), iou)

        boxes, scores = self._create_tensors_with_iou(1000, iou)
        r_cpu = ops.nms(boxes, scores, iou)
        r_gpu = ops.nms(boxes.to(device), scores.to(device), iou).cpu()

        # A different number of kept boxes is a real mismatch. Report it through the
        # assertion message instead of letting the comparison below raise or broadcast.
        assert r_cpu.shape == r_gpu.shape, err_msg

        is_eq = torch.equal(r_cpu, r_gpu)
        if not is_eq:
            # if the indices are not the same, ensure that it's because the scores
            # are duplicate
            is_eq = torch.allclose(scores[r_cpu], scores[r_gpu], rtol=tol, atol=tol)
        assert is_eq, err_msg

    @needs_cuda
    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.parametrize("dtype", (torch.float, torch.half))
    @pytest.mark.opcheck_only_one()
    def test_autocast(self, iou, dtype):
        with torch.cuda.amp.autocast():
            self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")

    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.parametrize("dtype", (torch.float, torch.bfloat16))
    def test_autocast_cpu(self, iou, dtype):
        boxes, scores = self._create_tensors_with_iou(1000, iou)
        with torch.cpu.amp.autocast():
            # Round-trip through ``dtype`` so the float reference sees the same rounded values.
            keep_ref_float = ops.nms(boxes.to(dtype).float(), scores.to(dtype).float(), iou)
            keep_dtype = ops.nms(boxes.to(dtype), scores.to(dtype), iou)
        torch.testing.assert_close(keep_ref_float, keep_dtype)

    @pytest.mark.parametrize(
        "device",
        (
            pytest.param("cuda", marks=pytest.mark.needs_cuda),
            pytest.param("mps", marks=pytest.mark.needs_mps),
        ),
    )
    @pytest.mark.opcheck_only_one()
    def test_nms_float16(self, device):
        boxes = torch.tensor(
            [
                [285.3538, 185.5758, 1193.5110, 851.4551],
                [285.1472, 188.7374, 1192.4984, 851.0669],
                [279.2440, 197.9812, 1189.4746, 849.2019],
            ]
        ).to(device)
        scores = torch.tensor([0.6370, 0.7569, 0.3966]).to(device)

        iou_thres = 0.2
        keep32 = ops.nms(boxes, scores, iou_thres)
        keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
        assert_equal(keep32, keep16)

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.opcheck_only_one()
    def test_batched_nms_implementations(self, seed):
        """Make sure that both implementations of batched_nms yield identical results"""
        torch.random.manual_seed(seed)

        num_boxes = 1000
        iou_threshold = 0.9

        boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1)
        # Tensor reductions instead of Python's max()/min(), which iterate element by element.
        assert boxes[:, 0].max() < boxes[:, 2].min()  # x1 < x2
        assert boxes[:, 1].max() < boxes[:, 3].min()  # y1 < y2

        scores = torch.rand(num_boxes)
        idxs = torch.randint(0, 4, size=(num_boxes,))
        keep_vanilla = ops.boxes._batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
        keep_trick = ops.boxes._batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)

        torch.testing.assert_close(
            keep_vanilla, keep_trick, msg="The vanilla and the trick implementation yield different nms outputs."
        )

        # Also make sure an empty tensor is returned if boxes is empty
        empty = torch.empty((0,), dtype=torch.int64)
        torch.testing.assert_close(empty, ops.batched_nms(empty, None, None, None))
918

919

920
921
922
923
924
925
926
927
928
# Ask torch's optests helper to generate opcheck tests for TestNMS, restricted to
# ops in the "torchvision" namespace and to the checks listed in OPTESTS.
# Known/expected failures are read from optests_failures_dict.json next to this file.
# NOTE(review): presumably this attaches the generated tests to TestNMS itself - confirm
# against the torch.testing._internal.optests implementation.
optests.generate_opcheck_tests(
    testcase=TestNMS,
    namespaces=["torchvision"],
    failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
    additional_decorators=[],
    test_utils=OPTESTS,
)


929
930
931
class TestDeformConv:
    """Tests for ``ops.deform_conv2d`` / ``ops.DeformConv2d`` against a naive reference implementation."""

    # Default dtype used by the forward/backward tests unless a test overrides it.
    dtype = torch.float64

    def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
        """Naive loop-based reference for (modulated) deformable convolution.

        Args:
            x: input of shape (n_batches, n_in_channels, in_h, in_w).
            weight: kernel of shape (n_out_channels, in_channels_per_weight_group, weight_h, weight_w).
            offset: sampling offsets; channel ``2 * k`` shifts rows and ``2 * k + 1`` shifts columns
                for kernel tap ``k`` of an offset group.
            mask: optional modulation mask; ``None`` means every tap has weight 1.
            bias: per-output-channel bias.
            stride, padding, dilation: int or pair, expanded with ``_pair``.

        Returns:
            Tensor of shape (n_batches, n_out_channels, out_h, out_w) on ``x``'s device and dtype.
        """
        stride_h, stride_w = _pair(stride)
        pad_h, pad_w = _pair(padding)
        dil_h, dil_w = _pair(dilation)
        weight_h, weight_w = weight.shape[-2:]

        n_batches, n_in_channels, in_h, in_w = x.shape
        n_out_channels = weight.shape[0]

        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1

        # Group counts are inferred from the tensor shapes rather than passed in.
        n_offset_grps = offset.shape[1] // (2 * weight_h * weight_w)
        in_c_per_offset_grp = n_in_channels // n_offset_grps

        n_weight_grps = n_in_channels // weight.shape[1]
        in_c_per_weight_grp = weight.shape[1]
        out_c_per_weight_grp = n_out_channels // n_weight_grps

        out = torch.zeros(n_batches, n_out_channels, out_h, out_w, device=x.device, dtype=x.dtype)
        for b in range(n_batches):
            for c_out in range(n_out_channels):
                for i in range(out_h):
                    for j in range(out_w):
                        for di in range(weight_h):
                            for dj in range(weight_w):
                                for c in range(in_c_per_weight_grp):
                                    weight_grp = c_out // out_c_per_weight_grp
                                    c_in = weight_grp * in_c_per_weight_grp + c

                                    offset_grp = c_in // in_c_per_offset_grp
                                    mask_idx = offset_grp * (weight_h * weight_w) + di * weight_w + dj
                                    offset_idx = 2 * mask_idx

                                    # Regular sampling position shifted by the learned (fractional) offset.
                                    pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j]
                                    pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j]

                                    mask_value = 1.0
                                    if mask is not None:
                                        mask_value = mask[b, mask_idx, i, j]

                                    out[b, c_out, i, j] += (
                                        mask_value
                                        * weight[c_out, c, di, dj]
                                        * bilinear_interpolate(x[b, c_in, :, :], pi, pj)
                                    )
        out += bias.view(1, n_out_channels, 1, 1)
        return out

    # NOTE(review): lru_cache on a method also keys on (and keeps alive) ``self``; left unchanged here.
    @lru_cache(maxsize=None)
    def get_fn_args(self, device, contiguous, batch_sz, dtype):
        """Create random inputs for a fixed deformable-conv configuration.

        Returns ``(x, weight, offset, mask, bias, stride, pad, dilation)``. All tensors have
        ``requires_grad=True``. With ``contiguous=False`` the tensors keep their logical shape
        but get a permuted (non-contiguous) memory layout.
        """
        n_in_channels = 6
        n_out_channels = 2
        n_weight_grps = 2
        n_offset_grps = 3

        stride = (2, 1)
        pad = (1, 0)
        dilation = (2, 1)

        stride_h, stride_w = stride
        pad_h, pad_w = pad
        dil_h, dil_w = dilation
        weight_h, weight_w = (3, 2)
        in_h, in_w = (5, 4)

        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1

        x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=dtype, requires_grad=True)

        offset = torch.randn(
            batch_sz,
            n_offset_grps * 2 * weight_h * weight_w,
            out_h,
            out_w,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )

        mask = torch.randn(
            batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w, device=device, dtype=dtype, requires_grad=True
        )

        weight = torch.randn(
            n_out_channels,
            n_in_channels // n_weight_grps,
            weight_h,
            weight_w,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )

        bias = torch.randn(n_out_channels, device=device, dtype=dtype, requires_grad=True)

        if not contiguous:
            # permute -> contiguous -> inverse permute: same values and shape, different strides.
            x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
            offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)

        return x, weight, offset, mask, bias, stride, pad, dilation

    def make_obj(self, in_channels=6, out_channels=2, kernel_size=(3, 2), groups=2, wrap=False):
        """Build a ``DeformConv2d`` matching ``get_fn_args``' stride/padding/dilation, optionally wrapped."""
        obj = ops.DeformConv2d(
            in_channels, out_channels, kernel_size, stride=(2, 1), padding=(1, 0), dilation=(2, 1), groups=groups
        )
        return DeformConvModuleWrapper(obj) if wrap else obj

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_is_leaf_node(self, device):
        op_obj = self.make_obj(wrap=True).to(device=device)
        graph_node_names = get_graph_node_names(op_obj)

        assert len(graph_node_names) == 2
        assert len(graph_node_names[0]) == len(graph_node_names[1])
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("batch_sz", (0, 33))
    @pytest.mark.opcheck_only_one()
    def test_forward(self, device, contiguous, batch_sz, dtype=None):
        dtype = dtype or self.dtype
        x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
        in_channels = 6
        out_channels = 2
        kernel_size = (3, 2)
        groups = 2
        tol = 2e-3 if dtype is torch.half else 1e-5

        layer = self.make_obj(in_channels, out_channels, kernel_size, groups, wrap=False).to(
            device=x.device, dtype=dtype
        )
        res = layer(x, offset, mask)

        # Feed the layer's own parameters to the reference so both compute the same function.
        weight = layer.weight.data
        bias = layer.bias.data
        expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation)

        torch.testing.assert_close(
            res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}"
        )

        # no modulation test
        res = layer(x, offset)
        expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation)

        torch.testing.assert_close(
            res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}"
        )

    def test_wrong_sizes(self):
        in_channels = 6
        out_channels = 2
        kernel_size = (3, 2)
        groups = 2
        x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(
            "cpu", contiguous=True, batch_sz=10, dtype=self.dtype
        )
        layer = ops.DeformConv2d(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups
        )
        # Truncating the channel dimension makes offset/mask inconsistent with the kernel size.
        with pytest.raises(RuntimeError, match="the shape of the offset"):
            wrong_offset = torch.rand_like(offset[:, :2])
            layer(x, wrong_offset)

        with pytest.raises(RuntimeError, match=r"mask.shape\[1\] is not valid"):
            wrong_mask = torch.rand_like(mask[:, :2])
            layer(x, offset, wrong_mask)

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("batch_sz", (0, 33))
    @pytest.mark.opcheck_only_one()
    def test_backward(self, device, contiguous, batch_sz):
        x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(
            device, contiguous, batch_sz, self.dtype
        )

        def func(x_, offset_, mask_, weight_, bias_):
            return ops.deform_conv2d(
                x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=mask_
            )

        gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5, fast_mode=True)

        def func_no_mask(x_, offset_, weight_, bias_):
            return ops.deform_conv2d(
                x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=None
            )

        gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5, fast_mode=True)

        # Same two checks again, this time through TorchScript-compiled wrappers.
        @torch.jit.script
        def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_):
            # type:(Tensor, Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
            return ops.deform_conv2d(
                x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_, mask=mask_
            )

        gradcheck(
            lambda z, off, msk, wei, bi: script_func(z, off, msk, wei, bi, stride, padding, dilation),
            (x, offset, mask, weight, bias),
            nondet_tol=1e-5,
            fast_mode=True,
        )

        @torch.jit.script
        def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
            # type:(Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
            return ops.deform_conv2d(
                x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_, mask=None
            )

        gradcheck(
            lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation),
            (x, offset, weight, bias),
            nondet_tol=1e-5,
            fast_mode=True,
        )

    @needs_cuda
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.opcheck_only_one()
    def test_compare_cpu_cuda_grads(self, contiguous):
        # Test from https://github.com/pytorch/vision/issues/2598
        # Run on CUDA only

        # compare grads computed on CUDA with grads computed on CPU
        true_cpu_grads = None

        init_weight = torch.randn(9, 9, 3, 3, requires_grad=True)
        img = torch.randn(8, 9, 1000, 110)
        offset = torch.rand(8, 2 * 3 * 3, 1000, 110)
        mask = torch.rand(8, 3 * 3, 1000, 110)

        if not contiguous:
            img = img.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
            offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            weight = init_weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
        else:
            weight = init_weight

        for d in ["cpu", "cuda"]:
            # Start every device from a clean slate: backward() accumulates into an existing
            # ``.grad`` in place, so without the reset the CUDA pass would add onto the CPU result.
            init_weight.grad = None
            out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d))
            out.mean().backward()
            assert init_weight.grad is not None
            # Snapshot the gradient; keeping a bare reference would alias the tensor that the
            # next backward pass overwrites, turning the comparison into "x equals x".
            grads = init_weight.grad.detach().clone().to("cpu")
            if true_cpu_grads is None:
                true_cpu_grads = grads
            else:
                torch.testing.assert_close(true_cpu_grads, grads)

    @needs_cuda
    @pytest.mark.parametrize("batch_sz", (0, 33))
    @pytest.mark.parametrize("dtype", (torch.float, torch.half))
    @pytest.mark.opcheck_only_one()
    def test_autocast(self, batch_sz, dtype):
        with torch.cuda.amp.autocast():
            self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype)

    def test_forward_scriptability(self):
        # Non-regression test for https://github.com/pytorch/vision/issues/4078
        torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3))

1202

1203
1204
1205
1206
1207
1208
1209
1210
1211
# Ask torch's optests helper to generate opcheck tests for TestDeformConv, restricted to
# ops in the "torchvision" namespace and to the checks listed in OPTESTS.
# Known/expected failures are read from optests_failures_dict.json next to this file.
# NOTE(review): presumably this attaches the generated tests to TestDeformConv itself - confirm
# against the torch.testing._internal.optests implementation.
optests.generate_opcheck_tests(
    testcase=TestDeformConv,
    namespaces=["torchvision"],
    failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
    additional_decorators=[],
    test_utils=OPTESTS,
)


1212
class TestFrozenBNT:
    """Tests for ``ops.misc.FrozenBatchNorm2d``."""

    def test_frozenbatchnorm2d_repr(self):
        channels, epsilon = 32, 1e-5
        layer = ops.misc.FrozenBatchNorm2d(channels, eps=epsilon)

        # The repr must echo the constructor arguments.
        assert repr(layer) == f"FrozenBatchNorm2d({channels}, eps={epsilon})"

    @pytest.mark.parametrize("seed", range(10))
    def test_frozenbatchnorm2d_eps(self, seed):
        torch.random.manual_seed(seed)
        shape = (4, 32, 28, 28)
        channels = shape[1]
        x = torch.rand(shape)

        # Random affine parameters and running statistics shared by both layers.
        state_dict = {name: torch.rand(channels) for name in ("weight", "bias", "running_mean", "running_var")}
        state_dict["num_batches_tracked"] = torch.tensor(100)

        # First pass: both layers use their default eps, which must agree.
        # Second pass: the same comparison with an explicit eps > 0.
        for eps_kwargs in ({}, {"eps": 1e-5}):
            frozen = ops.misc.FrozenBatchNorm2d(channels, **eps_kwargs)
            frozen.load_state_dict(state_dict, strict=False)
            reference = torch.nn.BatchNorm2d(channels, **eps_kwargs).eval()
            reference.load_state_dict(state_dict)
            # Outputs are expected to agree within a small numerical tolerance.
            torch.testing.assert_close(frozen(x), reference(x), rtol=1e-5, atol=1e-6)
1249

1250

Aditya Oke's avatar
Aditya Oke committed
1251
class TestBoxConversionToRoi:
    """Tests for the box-argument helpers in ``ops._utils``."""

    # Deliberately a plain function (no ``self``, no ``@staticmethod``): it is called while
    # the class body is being executed to feed ``parametrize`` below.
    def _get_box_sequences():
        """Return the same two boxes as a 5-column tensor, a list of tensors and a tuple of tensors."""
        # Define here the argument type of `boxes` supported by region pooling operations
        box_tensor = torch.tensor([[0, 0, 0, 100, 100], [1, 0, 0, 100, 100]], dtype=torch.float)
        box_list = [
            torch.tensor([[0, 0, 100, 100]], dtype=torch.float),
            torch.tensor([[0, 0, 100, 100]], dtype=torch.float),
        ]
        box_tuple = tuple(box_list)
        return box_tensor, box_list, box_tuple

    @pytest.mark.parametrize("box_sequence", _get_box_sequences())
    def test_check_roi_boxes_shape(self, box_sequence):
        # Ensure common sequences of tensors are supported
        ops._utils.check_roi_boxes_shape(box_sequence)

    @pytest.mark.parametrize("box_sequence", _get_box_sequences())
    def test_convert_boxes_to_roi_format(self, box_sequence):
        # Ensure common sequences of tensors yield the same result
        # The 5-column tensor (first column = index of the box's image) is the reference
        # every other representation must convert to.
        ref_tensor, _, _ = TestBoxConversionToRoi._get_box_sequences()
        if isinstance(box_sequence, torch.Tensor):
            # NOTE(review): a tensor input is assumed to already be in RoI format and is
            # not passed through the converter - confirm against ops._utils.
            rois = box_sequence
        else:
            rois = ops._utils.convert_boxes_to_roi_format(box_sequence)
        assert_equal(ref_tensor, rois)
1275
1276


Aditya Oke's avatar
Aditya Oke committed
1277
class TestBoxConvert:
    """Tests for ``ops.box_convert`` between the xyxy, xywh and cxcywh formats."""

    def test_bbox_same(self):
        # Converting a format to itself must return the boxes unchanged.
        rows = [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]]
        boxes = torch.tensor(rows, dtype=torch.float)
        unchanged = torch.tensor(rows, dtype=torch.float)

        assert unchanged.size() == torch.Size([4, 4])
        for fmt in ("xyxy", "xywh", "cxcywh"):
            assert_equal(ops.box_convert(boxes, in_fmt=fmt, out_fmt=fmt), unchanged)

    def test_bbox_xyxy_xywh(self):
        # Round trip xyxy -> xywh -> xyxy; the input boxes are x1 y1 x2 y2.
        xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
        want_xywh = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float
        )

        assert want_xywh.size() == torch.Size([4, 4])
        got_xywh = ops.box_convert(xyxy, in_fmt="xyxy", out_fmt="xywh")
        assert_equal(got_xywh, want_xywh)

        # Converting back must recover the original boxes.
        restored = ops.box_convert(got_xywh, in_fmt="xywh", out_fmt="xyxy")
        assert_equal(restored, xyxy)

    def test_bbox_xyxy_cxcywh(self):
        # Round trip xyxy -> cxcywh -> xyxy; the input boxes are x1 y1 x2 y2.
        xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
        want_cxcywh = torch.tensor(
            [[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float
        )

        assert want_cxcywh.size() == torch.Size([4, 4])
        got_cxcywh = ops.box_convert(xyxy, in_fmt="xyxy", out_fmt="cxcywh")
        assert_equal(got_cxcywh, want_cxcywh)

        # Converting back must recover the original boxes.
        restored = ops.box_convert(got_cxcywh, in_fmt="cxcywh", out_fmt="xyxy")
        assert_equal(restored, xyxy)

    def test_bbox_xywh_cxcywh(self):
        # Round trip xywh -> cxcywh -> xywh.
        xywh = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float)
        want_cxcywh = torch.tensor(
            [[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float
        )

        assert want_cxcywh.size() == torch.Size([4, 4])
        got_cxcywh = ops.box_convert(xywh, in_fmt="xywh", out_fmt="cxcywh")
        assert_equal(got_cxcywh, want_cxcywh)

        # Converting back must recover the original boxes.
        restored = ops.box_convert(got_cxcywh, in_fmt="cxcywh", out_fmt="xywh")
        assert_equal(restored, xywh)

    @pytest.mark.parametrize("inv_infmt", ["xwyh", "cxwyh"])
    @pytest.mark.parametrize("inv_outfmt", ["xwcx", "xhwcy"])
    def test_bbox_invalid(self, inv_infmt, inv_outfmt):
        # Unknown format names must be rejected.
        boxes = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float)

        with pytest.raises(ValueError):
            ops.box_convert(boxes, inv_infmt, inv_outfmt)

    def test_bbox_convert_jit(self):
        # The TorchScript-compiled function must agree with eager mode.
        boxes = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
        scripted_fn = torch.jit.script(ops.box_convert)

        for target_fmt in ("xywh", "cxcywh"):
            eager_out = ops.box_convert(boxes, in_fmt="xyxy", out_fmt=target_fmt)
            scripted_out = scripted_fn(boxes, "xyxy", target_fmt)
            torch.testing.assert_close(scripted_out, eager_out)
1365
1366


Aditya Oke's avatar
Aditya Oke committed
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
class TestBoxArea:
    """Tests for ``ops.box_area`` on integer, float and half-precision boxes."""

    def area_check(self, box, expected, atol=1e-4):
        """Assert that ``ops.box_area(box)`` equals ``expected`` within ``atol``, ignoring dtype."""
        computed = ops.box_area(box)
        torch.testing.assert_close(computed, expected, atol=atol, rtol=0.0, check_dtype=False)

    @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64])
    def test_int_boxes(self, dtype):
        boxes = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype)
        areas = torch.tensor([10000, 0], dtype=torch.int32)
        self.area_check(boxes, areas)

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
    def test_float_boxes(self, dtype):
        boxes = torch.tensor(FLOAT_BOXES, dtype=dtype)
        areas = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=dtype)
        self.area_check(boxes, areas)

    def test_float16_box(self):
        half_rows = [[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]]
        boxes = torch.tensor(half_rows, dtype=torch.float16)
        areas = torch.tensor([3.2170, 3.7108, 18.5071], dtype=torch.float16)
        # Half precision needs a looser absolute tolerance than the default.
        self.area_check(boxes, areas, atol=0.01)

    def test_box_area_jit(self):
        # The TorchScript-compiled function must agree with eager mode.
        boxes = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)
        eager_area = ops.box_area(boxes)
        scripted_area = torch.jit.script(ops.box_area)(boxes)
        torch.testing.assert_close(scripted_area, eager_area)
1398

Aditya Oke's avatar
Aditya Oke committed
1399

1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
# Shared box fixtures for the box tests in this file.
# NOTE(review): rows look like (x1, y1, x2, y2) corner-format boxes - confirm against the callers.
# INT_BOXES2 is INT_BOXES without its last row.
INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300], [0, 0, 25, 25]]
INT_BOXES2 = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
# Three large float boxes; TestBoxArea.test_float_boxes computes their areas.
FLOAT_BOXES = [
    [285.3538, 185.5758, 1193.5110, 851.4551],
    [285.1472, 188.7374, 1192.4984, 851.0669],
    [279.2440, 197.9812, 1189.4746, 849.2019],
]


def gen_box(size, dtype=torch.float):
    """Return ``size`` random boxes as a ``(size, 4)`` tensor of ``dtype``.

    The first two columns are a random point in [0, 1); the last two are that
    point plus a second random offset in [0, 1), so columns 2-3 are never
    smaller than columns 0-1.
    """
    corner = torch.rand((size, 2), dtype=dtype)
    extent = torch.rand((size, 2), dtype=dtype)
    return torch.cat((corner, corner + extent), dim=-1)


Aditya Oke's avatar
Aditya Oke committed
1415
1416
class TestIouBase:
    """Shared static helpers for the pairwise box-IoU test classes."""

    @staticmethod
    def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected):
        """Check ``target_fn`` against ``expected`` once per dtype in ``dtypes``.

        Args:
            target_fn: callable taking two box tensors and returning a matrix.
            actual_box1: raw (nested list) boxes for the first operand.
            actual_box2: raw (nested list) boxes for the second operand.
            dtypes: dtypes the raw boxes are converted to, one run per dtype.
            atol: absolute tolerance for the comparison (rtol is fixed to 0).
            expected: nested list with the expected output matrix.

        Raises:
            AssertionError: if any run deviates from ``expected`` by more than ``atol``.
        """
        # The expected matrix does not depend on the dtype under test; build it once.
        expected_box = torch.tensor(expected)
        for dtype in dtypes:
            # Build fresh tensors from the raw inputs on every iteration. Rebinding the
            # arguments would make each later dtype convert the previous iteration's
            # already-rounded tensor (e.g. float64 built from float32 data) and would
            # trigger torch.tensor's copy-construct warning.
            box1 = torch.tensor(actual_box1, dtype=dtype)
            box2 = torch.tensor(actual_box2, dtype=dtype)
            out = target_fn(box1, box2)
            torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)

    @staticmethod
    def _run_jit_test(target_fn: Callable, actual_box: List):
        """Check that the TorchScript-compiled ``target_fn`` matches the eager call.

        The boxes are converted to a float tensor and compared against themselves.
        """
        box_tensor = torch.tensor(actual_box, dtype=torch.float)
        expected = target_fn(box_tensor, box_tensor)
        scripted_fn = torch.jit.script(target_fn)
        scripted_out = scripted_fn(box_tensor, box_tensor)
        torch.testing.assert_close(scripted_out, expected)

    @staticmethod
    def _cartesian_product(boxes1, boxes2, target_fn: Callable):
        """Fill an ``(N, M)`` matrix by calling ``target_fn`` on every single-box pair."""
        N = boxes1.size(0)
        M = boxes2.size(0)
        result = torch.zeros((N, M))
        for i in range(N):
            for j in range(M):
                # unsqueeze(0) turns each box into a batch of one before the call.
                result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0))
        return result

    @staticmethod
    def _run_cartesian_test(target_fn: Callable):
        """Check that the batched call equals the pair-by-pair result on random boxes."""
        boxes1 = gen_box(5)
        boxes2 = gen_box(7)
        a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn)
        b = target_fn(boxes1, boxes2)
        torch.testing.assert_close(a, b)


class TestBoxIou(TestIouBase):
    """Tests for ``ops.box_iou`` built on the shared ``TestIouBase`` helpers."""

    # Expected matrix for INT_BOXES (rows) against INT_BOXES2 (columns).
    int_expected = [
        [1.0, 0.25, 0.0],
        [0.25, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.0625, 0.25, 0.0],
    ]
    # Expected matrix for FLOAT_BOXES against itself.
    float_expected = [
        [1.0, 0.9933, 0.9673],
        [0.9933, 1.0, 0.9737],
        [0.9673, 0.9737, 1.0],
    ]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            (INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        """Compare ``ops.box_iou`` with the expected matrix for every listed dtype."""
        iou_fn = ops.box_iou
        self._run_test(iou_fn, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        """The scripted op must match the eager op on INT_BOXES."""
        iou_fn = ops.box_iou
        self._run_jit_test(iou_fn, INT_BOXES)

    def test_iou_cartesian(self):
        """The batched op must match the pair-by-pair evaluation."""
        iou_fn = ops.box_iou
        self._run_cartesian_test(iou_fn)

class TestGeneralizedBoxIou(TestIouBase):
    """Tests for ``ops.generalized_box_iou`` built on the shared ``TestIouBase`` helpers."""

    # Expected matrix for INT_BOXES (rows) against INT_BOXES2 (columns).
    int_expected = [
        [1.0, 0.25, -0.7778],
        [0.25, 1.0, -0.8611],
        [-0.7778, -0.8611, 1.0],
        [0.0625, 0.25, -0.8819],
    ]
    # Expected matrix for FLOAT_BOXES against itself.
    float_expected = [
        [1.0, 0.9933, 0.9673],
        [0.9933, 1.0, 0.9737],
        [0.9673, 0.9737, 1.0],
    ]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            (INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        """Compare ``ops.generalized_box_iou`` with the expected matrix for every listed dtype."""
        iou_fn = ops.generalized_box_iou
        self._run_test(iou_fn, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        """The scripted op must match the eager op on INT_BOXES."""
        iou_fn = ops.generalized_box_iou
        self._run_jit_test(iou_fn, INT_BOXES)

    def test_iou_cartesian(self):
        """The batched op must match the pair-by-pair evaluation."""
        iou_fn = ops.generalized_box_iou
        self._run_cartesian_test(iou_fn)

class TestDistanceBoxIoU(TestIouBase):
    """Tests for ``ops.distance_box_iou`` built on the shared ``TestIouBase`` helpers."""

    # Expected matrix for INT_BOXES (rows) against INT_BOXES2 (columns).
    int_expected = [
        [1.0000, 0.1875, -0.4444],
        [0.1875, 1.0000, -0.5625],
        [-0.4444, -0.5625, 1.0000],
        [-0.0781, 0.1875, -0.6267],
    ]
    # Expected matrix for FLOAT_BOXES against itself.
    float_expected = [
        [1.0, 0.9933, 0.9673],
        [0.9933, 1.0, 0.9737],
        [0.9673, 0.9737, 1.0],
    ]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            (INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        """Compare ``ops.distance_box_iou`` with the expected matrix for every listed dtype."""
        iou_fn = ops.distance_box_iou
        self._run_test(iou_fn, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        """The scripted op must match the eager op on INT_BOXES."""
        iou_fn = ops.distance_box_iou
        self._run_jit_test(iou_fn, INT_BOXES)

    def test_iou_cartesian(self):
        """The batched op must match the pair-by-pair evaluation."""
        iou_fn = ops.distance_box_iou
        self._run_cartesian_test(iou_fn)

class TestCompleteBoxIou(TestIouBase):
    """Tests for ``ops.complete_box_iou`` built on the shared ``TestIouBase`` helpers."""

    # Expected matrix for INT_BOXES (rows) against INT_BOXES2 (columns).
    int_expected = [
        [1.0000, 0.1875, -0.4444],
        [0.1875, 1.0000, -0.5625],
        [-0.4444, -0.5625, 1.0000],
        [-0.0781, 0.1875, -0.6267],
    ]
    # Expected matrix for FLOAT_BOXES against itself.
    float_expected = [
        [1.0, 0.9933, 0.9673],
        [0.9933, 1.0, 0.9737],
        [0.9673, 0.9737, 1.0],
    ]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            (INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        """Compare ``ops.complete_box_iou`` with the expected matrix for every listed dtype."""
        iou_fn = ops.complete_box_iou
        self._run_test(iou_fn, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        """The scripted op must match the eager op on INT_BOXES."""
        iou_fn = ops.complete_box_iou
        self._run_jit_test(iou_fn, INT_BOXES)

    def test_iou_cartesian(self):
        """The batched op must match the pair-by-pair evaluation."""
        iou_fn = ops.complete_box_iou
        self._run_cartesian_test(iou_fn)

def get_boxes(dtype, device):
    """Build the fixed box fixtures used by the IoU-loss tests.

    Args:
        dtype: dtype of every returned tensor.
        device: device every returned tensor is created on.

    Returns:
        A 6-tuple ``(box1, box2, box3, box4, box1s, box2s)``: four single boxes of
        shape ``(4,)`` followed by two stacked batches of shape ``(2, 4)``, where
        ``box1s`` is ``[box2, box2]`` and ``box2s`` is ``[box3, box4]``.
    """

    def _as_box(coords):
        return torch.tensor(coords, dtype=dtype, device=device)

    box1 = _as_box([-1, -1, 1, 1])
    box2 = _as_box([0, 0, 1, 1])
    box3 = _as_box([0, 1, 1, 2])
    box4 = _as_box([1, 1, 2, 2])

    box1s = torch.stack((box2, box2), dim=0)
    box2s = torch.stack((box3, box4), dim=0)

    return box1, box2, box3, box4, box1s, box2s


def assert_iou_loss(iou_fn, box1, box2, expected_loss, device, reduction="none"):
    """Assert that ``iou_fn(box1, box2, reduction=reduction)`` is close to ``expected_loss``.

    ``expected_loss`` is wrapped in a tensor created on ``device`` and compared with
    ``torch.testing.assert_close`` using its default tolerances.
    """
    actual = iou_fn(box1, box2, reduction=reduction)
    reference = torch.tensor(expected_loss, device=device)
    torch.testing.assert_close(actual, reference)


def assert_empty_loss(iou_fn, dtype, device):
    """Check how ``iou_fn`` behaves on two empty ``(0, 4)`` box tensors.

    With ``reduction="mean"`` the loss must be close to 0 and backward must leave a
    gradient on both inputs; with ``reduction="none"`` the loss must have no elements.
    """

    def _empty_boxes():
        return torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()

    box1 = _empty_boxes()
    box2 = _empty_boxes()

    mean_loss = iou_fn(box1, box2, reduction="mean")
    mean_loss.backward()
    zero = torch.tensor(0.0, device=device)
    torch.testing.assert_close(mean_loss, zero)
    assert box1.grad is not None, "box1.grad should not be None after backward is called"
    assert box2.grad is not None, "box2.grad should not be None after backward is called"

    unreduced_loss = iou_fn(box1, box2, reduction="none")
    assert unreduced_loss.numel() == 0, f"{str(iou_fn)} for two empty box should be empty"


class TestGeneralizedBoxIouLoss:
    """Tests for ``ops.generalized_box_iou_loss``."""

    # We refer to original test: https://github.com/facebookresearch/fvcore/blob/main/tests/test_giou_loss.py
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_giou_loss(self, dtype, device):
        """Check unreduced, summed and averaged losses on the fixed box fixtures."""
        loss_fn = ops.generalized_box_iou_loss
        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

        # (first box, second box, expected loss) with the default reduction.
        single_pair_cases = [
            # Identical boxes should have loss of 0
            (box1, box1, 0.0),
            # quarter size box inside other box = IoU of 0.25
            (box1, box2, 0.75),
            # Two side by side boxes, area=union
            # IoU=0 and GIoU=0 (loss 1.0)
            (box2, box3, 1.0),
            # Two diagonally adjacent boxes, area=2*union
            # IoU=0 and GIoU=-0.5 (loss 1.5)
            (box2, box4, 1.5),
        ]
        for first, second, loss in single_pair_cases:
            assert_iou_loss(loss_fn, first, second, loss, device=device)

        # Test batched loss and reductions
        assert_iou_loss(loss_fn, box1s, box2s, 2.5, device=device, reduction="sum")
        assert_iou_loss(loss_fn, box1s, box2s, 1.25, device=device, reduction="mean")

        # Test reduction value
        # reduction value other than ["none", "mean", "sum"] should raise a ValueError
        with pytest.raises(ValueError, match="Invalid"):
            loss_fn(box1s, box2s, reduction="xyz")

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_empty_inputs(self, dtype, device):
        """Run the shared empty-input checks against the loss."""
        loss_fn = ops.generalized_box_iou_loss
        assert_empty_loss(loss_fn, dtype, device)


class TestCompleteBoxIouLoss:
    """Tests for ``ops.complete_box_iou_loss``."""

    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_ciou_loss(self, dtype, device):
        """Check unreduced, averaged and summed losses on the fixed box fixtures."""
        loss_fn = ops.complete_box_iou_loss
        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

        # box1 against each single box, with the default reduction.
        for other, loss in ((box1, 0.0), (box2, 0.8125), (box3, 1.1923), (box4, 1.2500)):
            assert_iou_loss(loss_fn, box1, other, loss, device=device)

        # Batched inputs with explicit reductions.
        assert_iou_loss(loss_fn, box1s, box2s, 1.2250, device=device, reduction="mean")
        assert_iou_loss(loss_fn, box1s, box2s, 2.4500, device=device, reduction="sum")

        # An unknown reduction name is expected to raise.
        with pytest.raises(ValueError, match="Invalid"):
            loss_fn(box1s, box2s, reduction="xyz")

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_empty_inputs(self, dtype, device):
        """Run the shared empty-input checks against the loss."""
        loss_fn = ops.complete_box_iou_loss
        assert_empty_loss(loss_fn, dtype, device)


class TestDistanceBoxIouLoss:
    """Tests for ``ops.distance_box_iou_loss``."""

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_distance_iou_loss(self, dtype, device):
        """Check unreduced, averaged and summed losses on the fixed box fixtures."""
        loss_fn = ops.distance_box_iou_loss
        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

        # box1 against each single box, with the default reduction.
        for other, loss in ((box1, 0.0), (box2, 0.8125), (box3, 1.1923), (box4, 1.2500)):
            assert_iou_loss(loss_fn, box1, other, loss, device=device)

        # Batched inputs with explicit reductions.
        assert_iou_loss(loss_fn, box1s, box2s, 1.2250, device=device, reduction="mean")
        assert_iou_loss(loss_fn, box1s, box2s, 2.4500, device=device, reduction="sum")

        # An unknown reduction name is expected to raise.
        with pytest.raises(ValueError, match="Invalid"):
            loss_fn(box1s, box2s, reduction="xyz")

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_empty_distance_iou_inputs(self, dtype, device):
        """Run the shared empty-input checks against the loss."""
        loss_fn = ops.distance_box_iou_loss
        assert_empty_loss(loss_fn, dtype, device)


class TestFocalLoss:
    """Tests for ``ops.sigmoid_focal_loss``."""

    def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs):
        """Build logits and targets that cover several probability/target regimes.

        Args:
            shape: shape of each generated chunk.
            **kwargs: forwarded to the tensor factories (the tests pass ``dtype``
                and ``device``).

        Returns:
            ``(inputs, targets)``; each is the concatenation of nine chunks, one per
            combination of input range and target kind listed below.
        """

        def logit(p):
            return torch.log(p / (1 - p))

        def generate_tensor_with_range_type(shape, range_type, **kwargs):
            if range_type == "random_binary":
                return torch.randint(0, 2, shape, **kwargs)

            low, high = {
                "small": (0.0, 0.2),
                "big": (0.8, 1.0),
                "zeros": (0.0, 0.0),
                "ones": (1.0, 1.0),
                "random": (0.0, 1.0),
            }[range_type]
            if low == high:
                # make_tensor deprecates low == high for floating dtypes (since
                # torch 2.1) in favour of torch.full, which builds the same
                # constant tensor without the deprecation path.
                return torch.full(shape, low, **kwargs)
            return torch.testing.make_tensor(shape, low=low, high=high, **kwargs)

        # This function will return inputs and targets with shape: (shape[0]*9, shape[1])
        inputs = []
        targets = []
        for input_range_type, target_range_type in [
            ("small", "zeros"),
            ("small", "ones"),
            ("small", "random_binary"),
            ("big", "zeros"),
            ("big", "ones"),
            ("big", "random_binary"),
            ("random", "zeros"),
            ("random", "ones"),
            ("random", "random_binary"),
        ]:
            # Inputs are probabilities mapped to logits; targets stay in {0, 1}.
            inputs.append(logit(generate_tensor_with_range_type(shape, input_range_type, **kwargs)))
            targets.append(generate_tensor_with_range_type(shape, target_range_type, **kwargs))

        return torch.cat(inputs), torch.cat(targets)

    @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
    @pytest.mark.parametrize("gamma", [0, 2])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("seed", [0, 1])
    def test_correct_ratio(self, alpha, gamma, device, dtype, seed):
        """The focal/cross-entropy ratio must match the manually computed modulating factor."""
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        # For testing the ratio with manual calculation, we require the reduction to be "none"
        reduction = "none"
        torch.random.manual_seed(seed)
        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
        focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
        ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction)

        assert torch.all(
            focal_loss <= ce_loss
        ), "focal loss must be less or equal to cross entropy loss with same input"

        loss_ratio = (focal_loss / ce_loss).squeeze()
        prob = torch.sigmoid(inputs)
        p_t = prob * targets + (1 - prob) * (1 - targets)
        correct_ratio = (1.0 - p_t) ** gamma
        # A negative alpha skips the alpha weighting in the reference computation.
        if alpha >= 0:
            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
            correct_ratio = correct_ratio * alpha_t

        tol = 1e-3 if dtype is torch.half else 1e-5
        torch.testing.assert_close(correct_ratio, loss_ratio, atol=tol, rtol=tol)

    @pytest.mark.parametrize("reduction", ["mean", "sum"])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("seed", [2, 3])
    def test_equal_ce_loss(self, reduction, device, dtype, seed):
        """With alpha=-1 and gamma=0 the focal loss and its gradient must equal cross entropy."""
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        # focal loss should be equal ce_loss if alpha=-1 and gamma=0
        alpha = -1
        gamma = 0
        torch.random.manual_seed(seed)
        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
        # Separate clones so each loss accumulates its own gradient.
        inputs_fl = inputs.clone().requires_grad_()
        targets_fl = targets.clone()
        inputs_ce = inputs.clone().requires_grad_()
        targets_ce = targets.clone()
        focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction)
        ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction)

        torch.testing.assert_close(focal_loss, ce_loss)

        focal_loss.backward()
        ce_loss.backward()
        torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad)

    @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
    @pytest.mark.parametrize("gamma", [0, 2])
    @pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("seed", [4, 5])
    def test_jit(self, alpha, gamma, reduction, device, dtype, seed):
        """The TorchScript-compiled loss must match the eager loss."""
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        script_fn = torch.jit.script(ops.sigmoid_focal_loss)
        torch.random.manual_seed(seed)
        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
        focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
        scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)

        tol = 1e-3 if dtype is torch.half else 1e-5
        torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol)

    # Raise ValueError for anonymous reduction mode
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_reduction_mode(self, device, dtype, reduction="xyz"):
        """An unknown reduction name must raise ``ValueError``."""
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        torch.random.manual_seed(0)
        inputs, targets = self._generate_diverse_input_target_pair(device=device, dtype=dtype)
        with pytest.raises(ValueError, match="Invalid"):
            ops.sigmoid_focal_loss(inputs, targets, 0.25, 2, reduction)

class TestMasksToBoxes:
    """Tests for ``ops.masks_to_boxes``."""

    def test_masks_box(self):
        """Boxes computed from the multi-frame ``masks.tiff`` asset must match the reference boxes."""

        def masks_box_check(masks, expected, atol=1e-4):
            out = ops.masks_to_boxes(masks)
            # The output must be float32 whatever the dtype of the masks is.
            assert out.dtype == torch.float
            torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=True, atol=atol)

        def _get_image():
            assets_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
            mask_path = os.path.join(assets_directory, "masks.tiff")
            return Image.open(mask_path)

        def _create_masks(image, masks):
            # Copy every frame of the image into the matching slot of ``masks``.
            for index in range(image.n_frames):
                image.seek(index)
                frame = np.array(image)
                masks[index] = torch.tensor(frame)

            return masks

        # One reference box per frame of the asset.
        expected = torch.tensor(
            [
                [127, 2, 165, 40],
                [2, 50, 44, 92],
                [56, 63, 98, 100],
                [139, 68, 175, 104],
                [160, 112, 198, 145],
                [49, 138, 99, 182],
                [108, 148, 152, 213],
            ],
            dtype=torch.float,
        )

        # The image is seeked repeatedly, so it must stay open while the masks are
        # built; the context manager closes the underlying file afterwards instead
        # of leaking the handle.
        with _get_image() as image:
            # Check masks stored in each floating-point dtype.
            for dtype in [torch.float16, torch.float32, torch.float64]:
                masks = torch.zeros((image.n_frames, image.height, image.width), dtype=dtype)
                masks = _create_masks(image, masks)
                masks_box_check(masks, expected)


class TestStochasticDepth:
    """Tests for ``ops.StochasticDepth``."""

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("p", [0.2, 0.5, 0.8])
    @pytest.mark.parametrize("mode", ["batch", "row"])
    def test_stochastic_depth_random(self, seed, mode, p):
        """The observed drop frequency must be statistically compatible with ``p``."""
        torch.manual_seed(seed)
        stats = pytest.importorskip("scipy.stats")
        batch_size = 5
        ones = torch.ones(size=(batch_size, 3, 4, 4))
        layer = ops.StochasticDepth(p=p, mode=mode)
        # Smoke-test the string representation.
        repr(layer)

        trials = 250
        dropped = 0
        observations = 0
        for _ in range(trials):
            # Number of batch rows that still contain non-zero values.
            surviving_rows = layer(ones).sum(dim=(1, 2, 3)).nonzero().size(0)
            if mode == "batch":
                # One observation per trial: was the whole batch zeroed?
                dropped += int(surviving_rows == 0)
                observations += 1
            elif mode == "row":
                # One observation per row of the batch.
                dropped += batch_size - surviving_rows
                observations += batch_size

        binomial_result = stats.binomtest(dropped, observations, p=p)
        assert binomial_result.pvalue > 0.01

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("p", (0, 1))
    @pytest.mark.parametrize("mode", ["batch", "row"])
    def test_stochastic_depth(self, seed, mode, p):
        """``p=0`` must leave the input untouched and ``p=1`` must zero it out."""
        torch.manual_seed(seed)
        batch_size = 5
        ones = torch.ones(size=(batch_size, 3, 4, 4))
        layer = ops.StochasticDepth(p=p, mode=mode)

        result = layer(ones)
        if p == 0:
            assert torch.equal(result, ones)
        elif p == 1:
            assert torch.equal(result, torch.zeros_like(ones))

    def make_obj(self, p, mode, wrap=False):
        """Create a ``StochasticDepth`` layer, optionally wrapped in ``StochasticDepthWrapper``."""
        layer = ops.StochasticDepth(p, mode)
        if wrap:
            return StochasticDepthWrapper(layer)
        return layer

    @pytest.mark.parametrize("p", (0, 1))
    @pytest.mark.parametrize("mode", ["batch", "row"])
    def test_is_leaf_node(self, p, mode):
        """Graph tracing must list one node per wrapper input plus one more."""
        wrapped = self.make_obj(p, mode, wrap=True)
        node_names = get_graph_node_names(wrapped)

        assert len(node_names) == 2
        first_names, second_names = node_names[0], node_names[1]
        assert len(first_names) == len(second_names)
        assert len(first_names) == 1 + wrapped.n_inputs

class TestUtils:
    """Tests for helpers in ``ops._utils``."""

    @pytest.mark.parametrize("norm_layer", [None, nn.BatchNorm2d, nn.LayerNorm])
    def test_split_normalization_params(self, norm_layer):
        """The two parameter groups of MobileNetV3-Large must have the expected sizes."""
        model = models.mobilenet_v3_large(norm_layer=norm_layer)
        # With no explicit norm layer, None is passed as the list of norm classes.
        norm_classes = [norm_layer] if norm_layer is not None else None
        params = ops._utils.split_normalization_params(model, norm_classes)

        first_group, second_group = params[0], params[1]
        assert len(first_group) == 92
        assert len(second_group) == 82


class TestDropBlock:
    """Tests for ``ops.DropBlock2d`` and ``ops.DropBlock3d``."""

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("dim", [2, 3])
    @pytest.mark.parametrize("p", [0, 0.5])
    @pytest.mark.parametrize("block_size", [5, 11])
    @pytest.mark.parametrize("inplace", [True, False])
    def test_drop_block(self, seed, dim, p, block_size, inplace):
        """``p=0`` keeps the input; a block as large as the feature map drops all or nothing."""
        torch.manual_seed(seed)
        batch_size, channels = 5, 3
        # Every spatial extent (depth, height, width) uses the same size.
        height = width = depth = 11
        if dim == 2:
            x = torch.ones(size=(batch_size, channels, height, width))
            layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace)
            feature_size = height * width
        elif dim == 3:
            x = torch.ones(size=(batch_size, channels, depth, height, width))
            layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace)
            feature_size = depth * height * width
        # Smoke-test the string representation.
        repr(layer)

        out = layer(x)
        if p == 0:
            assert torch.equal(out, x)
        if block_size == height:
            # Each (sample, channel) slice is either fully kept or fully dropped.
            for sample_idx, channel_idx in product(range(batch_size), range(channels)):
                assert out[sample_idx, channel_idx].count_nonzero() in (0, feature_size)

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("dim", [2, 3])
    @pytest.mark.parametrize("p", [0.1, 0.2])
    @pytest.mark.parametrize("block_size", [3])
    @pytest.mark.parametrize("inplace", [False])
    def test_drop_block_random(self, seed, dim, p, block_size, inplace):
        """The fraction of zeroed cells over many trials must be within 15% (relative) of ``p``."""
        torch.manual_seed(seed)
        batch_size, channels = 5, 3
        height = width = depth = 11
        if dim == 2:
            x = torch.ones(size=(batch_size, channels, height, width))
            layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace)
        elif dim == 3:
            x = torch.ones(size=(batch_size, channels, depth, height, width))
            layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace)

        trials = 250
        total_cells = 0
        zeroed_cells = 0
        cells_per_trial = torch.tensor(x.shape).prod()
        for _ in range(trials):
            with torch.no_grad():
                out = layer(x)
            kept_cells = out.nonzero().size(0)
            zeroed_cells += cells_per_trial - kept_cells
            total_cells += cells_per_trial

        assert abs(p - zeroed_cells / total_cells) / p < 0.15

    def make_obj(self, dim, p, block_size, inplace, wrap=False):
        """Create a 2d or 3d DropBlock layer, optionally wrapped in ``DropBlockWrapper``."""
        if dim == 2:
            layer = ops.DropBlock2d(p, block_size, inplace)
        elif dim == 3:
            layer = ops.DropBlock3d(p, block_size, inplace)
        if wrap:
            return DropBlockWrapper(layer)
        return layer

    @pytest.mark.parametrize("dim", (2, 3))
    @pytest.mark.parametrize("p", [0, 1])
    @pytest.mark.parametrize("block_size", [5, 7])
    @pytest.mark.parametrize("inplace", [True, False])
    def test_is_leaf_node(self, dim, p, block_size, inplace):
        """Graph tracing must list one node per wrapper input plus one more."""
        wrapped = self.make_obj(dim, p, block_size, inplace, wrap=True)
        node_names = get_graph_node_names(wrapped)

        assert len(node_names) == 2
        first_names, second_names = node_names[0], node_names[1]
        assert len(first_names) == len(second_names)
        assert len(first_names) == 1 + wrapped.n_inputs


if __name__ == "__main__":
    # pytest.main returns the exit status instead of exiting, so propagate it;
    # otherwise running this file directly would exit 0 even when tests fail.
    raise SystemExit(pytest.main([__file__]))