test_ops.py 79.3 KB
Newer Older
1
import math
2
import os
3
from abc import ABC, abstractmethod
4
from functools import lru_cache
5
from itertools import product
6
from typing import Callable, List, Tuple
7

8
import numpy as np
9
import pytest
10
import torch
11
import torch.fx
12
import torch.nn.functional as F
13
import torch.testing._internal.optests as optests
14
from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps
15
from PIL import Image
16
from torch import nn, Tensor
17
from torch.autograd import gradcheck
18
from torch.nn.modules.utils import _pair
19
from torchvision import models, ops
20
21
22
from torchvision.models.feature_extraction import get_graph_node_names


23
24
25
26
27
28
29
30
# Names of opcheck test utilities.
# NOTE(review): presumably consumed by an opcheck test generator further down
# in this file — the consumer is not visible in this chunk; confirm.
OPTESTS = [
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
]


31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# Context manager for setting deterministic flag and automatically
# resetting it to its original value
class DeterministicGuard:
    """Temporarily switch torch's deterministic-algorithms setting.

    Args:
        deterministic: value forwarded to ``torch.use_deterministic_algorithms``.
        warn_only: keyword-only; forwarded as the ``warn_only`` flag.

    On exit, both the deterministic flag and the warn-only flag are set back
    to the values that were active when the ``with`` block was entered.
    """

    def __init__(self, deterministic, *, warn_only=False):
        self.deterministic = deterministic
        self.warn_only = warn_only

    def __enter__(self):
        # Snapshot the current global state before overriding it.
        self.deterministic_restore = torch.are_deterministic_algorithms_enabled()
        self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled()
        torch.use_deterministic_algorithms(self.deterministic, warn_only=self.warn_only)
        # Return the guard so that ``with DeterministicGuard(...) as g`` binds it.
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        # Always restore; returns None so exceptions raised in the block propagate.
        torch.use_deterministic_algorithms(self.deterministic_restore, warn_only=self.warn_only_restore)


47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
class RoIOpTesterModuleWrapper(nn.Module):
    """Module that holds a two-input layer and calls it from ``forward``.

    ``n_inputs`` records the number of positional inputs ``forward`` accepts.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 2
        self.layer = obj

    def forward(self, a, b):
        # The wrapped layer is invoked, but its output is not returned.
        self.layer(a, b)
        return None


class MultiScaleRoIAlignModuleWrapper(nn.Module):
    """Module that holds a three-input layer and calls it from ``forward``.

    ``n_inputs`` records the number of positional inputs ``forward`` accepts.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 3
        self.layer = obj

    def forward(self, a, b, c):
        # The wrapped layer is invoked, but its output is not returned.
        self.layer(a, b, c)
        return None


class DeformConvModuleWrapper(nn.Module):
    """Module that holds a three-input layer and calls it from ``forward``.

    ``n_inputs`` records the number of positional inputs ``forward`` accepts.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 3
        self.layer = obj

    def forward(self, a, b, c):
        # The wrapped layer is invoked, but its output is not returned.
        self.layer(a, b, c)
        return None


class StochasticDepthWrapper(nn.Module):
    """Module that holds a single-input layer and calls it from ``forward``.

    ``n_inputs`` records the number of positional inputs ``forward`` accepts.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 1
        self.layer = obj

    def forward(self, a):
        # The wrapped layer is invoked, but its output is not returned.
        self.layer(a)
        return None
85
86


87
88
89
90
91
92
93
94
95
96
class DropBlockWrapper(nn.Module):
    """Module that holds a single-input layer and calls it from ``forward``.

    ``n_inputs`` records the number of positional inputs ``forward`` accepts.
    """

    def __init__(self, obj):
        super().__init__()
        self.n_inputs = 1
        self.layer = obj

    def forward(self, a):
        # The wrapped layer is invoked, but its output is not returned.
        self.layer(a)
        return None


97
98
99
100
101
102
103
104
105
class PoolWrapper(nn.Module):
    """Module that forwards ``(imgs, boxes)`` to a pooling module and returns its output."""

    def __init__(self, pool: nn.Module):
        super().__init__()
        self.pool = pool

    def forward(self, imgs: Tensor, boxes: List[Tensor]) -> Tensor:
        pooled = self.pool(imgs, boxes)
        return pooled


106
107
class RoIOpTester(ABC):
    """Shared test-suite for RoI ops.

    Concrete subclasses provide the op under test through the abstract hooks
    ``fn``, ``make_obj``, ``get_script_fn`` and ``expected_fn``; the tests
    defined here are then run against that op.
    """

    # dtype used by the backward (gradcheck) test on non-MPS devices
    dtype = torch.float64
    # dtype and absolute tolerance used by the backward test on MPS
    mps_dtype = torch.float32
    mps_backward_atol = 2e-2

    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize(
        "x_dtype",
        (
            torch.float16,
            torch.float32,
            torch.float64,
        ),
        ids=str,
    )
    def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, deterministic=False, **kwargs):
        """Compare ``self.fn`` against ``self.expected_fn`` on a fixed set of RoIs.

        Extra ``kwargs`` are forwarded to both ``self.fn`` and ``self.expected_fn``.
        """
        if device == "mps" and x_dtype is torch.float64:
            pytest.skip("MPS does not support float64")

        rois_dtype = x_dtype if rois_dtype is None else rois_dtype

        # Reduced-precision inputs are compared with a looser tolerance.
        tol = 1e-5
        if x_dtype is torch.half:
            if device == "mps":
                tol = 5e-3
            else:
                tol = 4e-3
        elif x_dtype == torch.bfloat16:
            tol = 5e-3

        pool_size = 5
        # n_channels % (pool_size ** 2) == 0 required for PS operations.
        n_channels = 2 * (pool_size**2)
        x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device)
        if not contiguous:
            x = x.permute(0, 1, 3, 2)
        rois = torch.tensor(
            [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]],  # format is (xyxy)
            dtype=rois_dtype,
            device=device,
        )

        pool_h, pool_w = pool_size, pool_size
        with DeterministicGuard(deterministic):
            y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs)
        # the following should be true whether we're running an autocast test or not.
        assert y.dtype == x.dtype
        gt_y = self.expected_fn(
            x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=x_dtype, **kwargs
        )

        torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_is_leaf_node(self, device):
        """Check the graph node names of the wrapped op: one node per input plus one more."""
        op_obj = self.make_obj(wrap=True).to(device=device)
        graph_node_names = get_graph_node_names(op_obj)

        assert len(graph_node_names) == 2
        assert len(graph_node_names[0]) == len(graph_node_names[1])
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.float):
        """Check that the symbolically traced module matches the eager module."""
        op_obj = self.make_obj().to(device=device)
        graph_module = torch.fx.symbolic_trace(op_obj)
        pool_size = 5
        n_channels = 2 * (pool_size**2)
        x = torch.rand(2, n_channels, 5, 5, dtype=x_dtype, device=device)
        rois = torch.tensor(
            [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]],  # format is (xyxy)
            dtype=rois_dtype,
            device=device,
        )
        output_gt = op_obj(x, rois)
        assert output_gt.dtype == x.dtype
        output_fx = graph_module(x, rois)
        assert output_fx.dtype == x.dtype
        tol = 1e-5
        torch.testing.assert_close(output_gt, output_fx, rtol=tol, atol=tol)

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
    @pytest.mark.parametrize("contiguous", (True, False))
    def test_backward(self, seed, device, contiguous, deterministic=False):
        """Run ``gradcheck`` on both the eager op and its scripted counterpart."""
        atol = self.mps_backward_atol if device == "mps" else 1e-05
        dtype = self.mps_dtype if device == "mps" else self.dtype

        torch.random.manual_seed(seed)
        pool_size = 2
        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=dtype, device=device, requires_grad=True)
        if not contiguous:
            x = x.permute(0, 1, 3, 2)
        rois = torch.tensor(
            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=dtype, device=device  # format is (xyxy)
        )

        def func(z):
            return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)

        script_func = self.get_script_fn(rois, pool_size)

        # Only the eager gradcheck runs under the deterministic guard.
        with DeterministicGuard(deterministic):
            gradcheck(func, (x,), atol=atol)

        gradcheck(script_func, (x,), atol=atol)

    @needs_mps
    def test_mps_error_inputs(self):
        """Check that gradcheck with float16 inputs on MPS raises the expected RuntimeError."""
        pool_size = 2
        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=torch.float16, device="mps", requires_grad=True)
        rois = torch.tensor(
            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=torch.float16, device="mps"  # format is (xyxy)
        )

        def func(z):
            return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)

        with pytest.raises(
            RuntimeError, match="MPS does not support (?:ps_)?roi_(?:align|pool)? backward with float16 inputs."
        ):
            gradcheck(func, (x,))

    @needs_cuda
    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
    def test_autocast(self, x_dtype, rois_dtype):
        """Re-run the forward test on CUDA inside an autocast context."""
        with torch.cuda.amp.autocast():
            self.test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)

    def _helper_boxes_shape(self, func):
        """Check that ``func`` raises ``AssertionError`` for wrongly shaped boxes.

        Args:
            func: functional RoI op called as ``func(input, boxes, output_size=...)``.
        """
        # test boxes as Tensor[N, 5]
        with pytest.raises(AssertionError):
            a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
            boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype)
            func(a, boxes, output_size=(2, 2))

        # test boxes as List[Tensor[N, 4]]
        with pytest.raises(AssertionError):
            a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
            boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype)
            # Exercise the op under test (previously hard-coded to ops.roi_pool).
            func(a, [boxes], output_size=(2, 2))

    def _helper_jit_boxes_list(self, model):
        """Script ``model`` and check the output shape when boxes are passed as a list of tensors."""
        x = torch.rand(2, 1, 10, 10)
        roi = torch.tensor([[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], dtype=torch.float).t()
        rois = [roi, roi]
        scripted = torch.jit.script(model)
        y = scripted(x, rois)
        assert y.shape == (10, 1, 3, 3)

    @abstractmethod
    def fn(*args, **kwargs):
        """Run the op under test; called as ``fn(x, rois, pool_h, pool_w, spatial_scale=..., sampling_ratio=...)``."""
        pass

    @abstractmethod
    def make_obj(*args, **kwargs):
        """Build the module form of the op; ``wrap=True`` is passed when a wrapper exposing ``n_inputs`` is needed."""
        pass

    @abstractmethod
    def get_script_fn(*args, **kwargs):
        """Return a callable of ``x`` backed by the scripted op; called as ``get_script_fn(rois, pool_size)``."""
        pass

    @abstractmethod
    def expected_fn(*args, **kwargs):
        """Compute the reference output that ``test_forward`` compares ``fn`` against."""
        pass
273

274

275
class TestRoiPool(RoIOpTester):
    """Runs the shared RoI op tests against ``ops.RoIPool`` / ``ops.roi_pool``."""

    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        """Apply ``ops.RoIPool``; ``sampling_ratio`` and extra kwargs are accepted but unused."""
        return ops.RoIPool((pool_h, pool_w), spatial_scale)(x, rois)

    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False):
        """Build an ``ops.RoIPool`` module, optionally wrapped in ``RoIOpTesterModuleWrapper``."""
        obj = ops.RoIPool((pool_h, pool_w), spatial_scale)
        return RoIOpTesterModuleWrapper(obj) if wrap else obj

    def get_script_fn(self, rois, pool_size):
        """Return a function of ``x`` that calls the scripted ``ops.roi_pool``."""
        scripted = torch.jit.script(ops.roi_pool)
        return lambda x: scripted(x, rois, pool_size)

    def expected_fn(
        self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64
    ):
        """Reference implementation: per-bin maximum over each RoI.

        Returns a tensor of shape ``(len(rois), n_channels, pool_h, pool_w)``
        created with the given ``dtype`` on ``device`` (CPU when ``None``).
        """
        if device is None:
            device = torch.device("cpu")

        n_channels = x.size(1)
        y = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device)

        def get_slice(k, block):
            # Index range covered by bin ``k`` when every bin is ``block`` wide.
            return slice(int(np.floor(k * block)), int(np.ceil((k + 1) * block)))

        for roi_idx, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (int(round(coord.item() * spatial_scale)) for coord in roi[1:])
            roi_x = x[batch_idx, :, i_begin : i_end + 1, j_begin : j_end + 1]

            roi_h, roi_w = roi_x.shape[-2:]
            bin_h = roi_h / pool_h
            bin_w = roi_w / pool_w

            for i in range(0, pool_h):
                for j in range(0, pool_w):
                    bin_x = roi_x[:, get_slice(i, bin_h), get_slice(j, bin_w)]
                    # Empty bins keep the zero the output was initialised with.
                    if bin_x.numel() > 0:
                        y[roi_idx, :, i, j] = bin_x.reshape(n_channels, -1).max(dim=1)[0]
        return y

    def test_boxes_shape(self):
        """Badly shaped boxes must be rejected by ``ops.roi_pool``."""
        self._helper_boxes_shape(ops.roi_pool)

    def test_jit_boxes_list(self):
        """Scripted ``ops.RoIPool`` must accept boxes given as a list of tensors."""
        model = PoolWrapper(ops.RoIPool(output_size=[3, 3], spatial_scale=1.0))
        self._helper_jit_boxes_list(model)

322

323
class TestPSRoIPool(RoIOpTester):
    """Runs the shared RoI op tests against ``ops.PSRoIPool`` / ``ops.ps_roi_pool``."""

    mps_backward_atol = 5e-2

    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        """Apply ``ops.PSRoIPool``; ``sampling_ratio`` and extra kwargs are accepted but unused."""
        # Honour the caller's spatial_scale instead of a hard-coded 1.
        return ops.PSRoIPool((pool_h, pool_w), spatial_scale)(x, rois)

    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False):
        """Build an ``ops.PSRoIPool`` module, optionally wrapped in ``RoIOpTesterModuleWrapper``."""
        obj = ops.PSRoIPool((pool_h, pool_w), spatial_scale)
        return RoIOpTesterModuleWrapper(obj) if wrap else obj

    def get_script_fn(self, rois, pool_size):
        """Return a function of ``x`` that calls the scripted ``ops.ps_roi_pool``."""
        scripted = torch.jit.script(ops.ps_roi_pool)
        return lambda x: scripted(x, rois, pool_size)

    def expected_fn(
        self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64
    ):
        """Reference implementation: per-bin average, one input channel per (output channel, bin).

        Returns a tensor of shape
        ``(len(rois), n_input_channels / (pool_h * pool_w), pool_h, pool_w)``
        created with the given ``dtype`` on ``device`` (CPU when ``None``).
        """
        if device is None:
            device = torch.device("cpu")
        n_input_channels = x.size(1)
        assert n_input_channels % (pool_h * pool_w) == 0, "input channels must be divisible by ph * pw"
        n_output_channels = int(n_input_channels / (pool_h * pool_w))
        y = torch.zeros(rois.size(0), n_output_channels, pool_h, pool_w, dtype=dtype, device=device)

        def get_slice(k, block):
            # Index range covered by bin ``k`` when every bin is ``block`` wide.
            return slice(int(np.floor(k * block)), int(np.ceil((k + 1) * block)))

        for roi_idx, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (int(round(coord.item() * spatial_scale)) for coord in roi[1:])
            roi_x = x[batch_idx, :, i_begin : i_end + 1, j_begin : j_end + 1]

            roi_height = max(i_end - i_begin, 1)
            roi_width = max(j_end - j_begin, 1)
            bin_h, bin_w = roi_height / float(pool_h), roi_width / float(pool_w)

            for i in range(0, pool_h):
                for j in range(0, pool_w):
                    bin_x = roi_x[:, get_slice(i, bin_h), get_slice(j, bin_w)]
                    # Empty bins keep the zero the output was initialised with.
                    if bin_x.numel() > 0:
                        area = bin_x.size(-2) * bin_x.size(-1)
                        for c_out in range(0, n_output_channels):
                            # Each (output channel, bin) pair reads its own input channel.
                            c_in = c_out * (pool_h * pool_w) + pool_w * i + j
                            t = torch.sum(bin_x[c_in, :, :])
                            y[roi_idx, c_out, i, j] = t / area
        return y

    def test_boxes_shape(self):
        """Badly shaped boxes must be rejected by ``ops.ps_roi_pool``."""
        self._helper_boxes_shape(ops.ps_roi_pool)

373

374
375
def bilinear_interpolate(data, y, x, snap_border=False):
    """Bilinearly interpolate the 2-D array ``data`` at the point ``(y, x)``.

    Args:
        data: 2-D array-like exposing ``.shape`` and ``data[row, col]`` indexing.
        y: row coordinate of the sampling point.
        x: column coordinate of the sampling point.
        snap_border: when True, a coordinate in ``(-1, 0]`` is moved to ``0`` and a
            coordinate in ``[size - 1, size)`` is moved to ``size - 1`` before interpolating.

    Returns:
        The weighted sum of the (up to four) neighbouring samples that lie inside
        ``data``; neighbours outside the array contribute nothing.
    """
    height, width = data.shape

    if snap_border:
        if -1 < y <= 0:
            y = 0
        elif height - 1 <= y < height:
            y = height - 1

        if -1 < x <= 0:
            x = 0
        elif width - 1 <= x < width:
            x = width - 1

    y_low = int(math.floor(y))
    x_low = int(math.floor(x))
    y_high = y_low + 1
    x_high = x_low + 1

    # Weights of the "high" neighbours are the fractional parts of the coordinates.
    wy_h = y - y_low
    wx_h = x - x_low
    wy_l = 1 - wy_h
    wx_l = 1 - wx_h

    val = 0
    for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):
        for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):
            # Skip neighbours that fall outside the array.
            if 0 <= yp < height and 0 <= xp < width:
                val += wx * wy * data[yp, xp]
    return val
404
405


406
class TestRoIAlign(RoIOpTester):
    """Runs the shared RoI op tests against ``ops.RoIAlign`` / ``ops.roi_align``, plus quantized checks."""

    mps_backward_atol = 6e-2

    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs):
        """Apply ``ops.RoIAlign``; extra kwargs are accepted but unused."""
        return ops.RoIAlign(
            (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
        )(x, rois)

    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, aligned=False, wrap=False):
        """Build an ``ops.RoIAlign`` module, optionally wrapped in ``RoIOpTesterModuleWrapper``."""
        obj = ops.RoIAlign(
            (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
        )
        return RoIOpTesterModuleWrapper(obj) if wrap else obj

    def get_script_fn(self, rois, pool_size):
        """Return a function of ``x`` that calls the scripted ``ops.roi_align``."""
        scripted = torch.jit.script(ops.roi_align)
        return lambda x: scripted(x, rois, pool_size)

    def expected_fn(
        self,
        in_data,
        rois,
        pool_h,
        pool_w,
        spatial_scale=1,
        sampling_ratio=-1,
        aligned=False,
        device=None,
        dtype=torch.float64,
    ):
        """Reference implementation: average of bilinear samples taken on a grid inside each bin.

        Returns a tensor of shape ``(len(rois), n_channels, pool_h, pool_w)``
        created with the given ``dtype`` on ``device`` (CPU when ``None``).
        """
        if device is None:
            device = torch.device("cpu")
        n_channels = in_data.size(1)
        out_data = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device)

        # With aligned=True the scaled box coordinates are shifted by half a pixel.
        offset = 0.5 if aligned else 0.0

        for r, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (coord.item() * spatial_scale - offset for coord in roi[1:])

            roi_h = i_end - i_begin
            roi_w = j_end - j_begin
            bin_h = roi_h / pool_h
            bin_w = roi_w / pool_w

            for i in range(0, pool_h):
                start_h = i_begin + i * bin_h
                # A non-positive sampling_ratio means: derive the grid size from the bin size.
                grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h))
                for j in range(0, pool_w):
                    start_w = j_begin + j * bin_w
                    grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))

                    for channel in range(0, n_channels):
                        val = 0
                        for iy in range(0, grid_h):
                            y = start_h + (iy + 0.5) * bin_h / grid_h
                            for ix in range(0, grid_w):
                                x = start_w + (ix + 0.5) * bin_w / grid_w
                                val += bilinear_interpolate(in_data[batch_idx, channel, :, :], y, x, snap_border=True)
                        val /= grid_h * grid_w

                        out_data[r, channel, i, j] = val
        return out_data

    def test_boxes_shape(self):
        """Badly shaped boxes must be rejected by ``ops.roi_align``."""
        self._helper_boxes_shape(ops.roi_align)

    @pytest.mark.parametrize("aligned", (True, False))
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
    @pytest.mark.parametrize("x_dtype", (torch.float16, torch.float32, torch.float64))
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("deterministic", (True, False))
    @pytest.mark.opcheck_only_one()
    def test_forward(self, device, contiguous, deterministic, aligned, x_dtype, rois_dtype=None):
        """Forward test, additionally parametrized over ``aligned`` and ``deterministic``."""
        if deterministic and device == "cpu":
            pytest.skip("cpu is always deterministic, don't retest")
        super().test_forward(
            device=device,
            contiguous=contiguous,
            deterministic=deterministic,
            x_dtype=x_dtype,
            rois_dtype=rois_dtype,
            aligned=aligned,
        )

    @needs_cuda
    @pytest.mark.parametrize("aligned", (True, False))
    @pytest.mark.parametrize("deterministic", (True, False))
    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
    @pytest.mark.opcheck_only_one()
    def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
        """Re-run the forward test on CUDA inside an autocast context."""
        with torch.cuda.amp.autocast():
            self.test_forward(
                torch.device("cuda"),
                contiguous=False,
                deterministic=deterministic,
                aligned=aligned,
                x_dtype=x_dtype,
                rois_dtype=rois_dtype,
            )

    @pytest.mark.parametrize("aligned", (True, False))
    @pytest.mark.parametrize("deterministic", (True, False))
    @pytest.mark.parametrize("x_dtype", (torch.float, torch.bfloat16))
    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.bfloat16))
    def test_autocast_cpu(self, aligned, deterministic, x_dtype, rois_dtype):
        """Re-run the forward test on CPU inside an autocast context."""
        with torch.cpu.amp.autocast():
            self.test_forward(
                torch.device("cpu"),
                contiguous=False,
                deterministic=deterministic,
                aligned=aligned,
                x_dtype=x_dtype,
                rois_dtype=rois_dtype,
            )

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("deterministic", (True, False))
    @pytest.mark.opcheck_only_one()
    def test_backward(self, seed, device, contiguous, deterministic):
        """Backward test, additionally parametrized over ``deterministic``."""
        if deterministic and device == "cpu":
            pytest.skip("cpu is always deterministic, don't retest")
        super().test_backward(seed, device, contiguous, deterministic)

    def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
        """Create ``num_rois`` random rois as a ``(num_rois, 5)`` tensor of ``dtype``; column 0 is the batch index."""
        rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype)
        rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,))  # set batch index
        rois[:, 3:] += rois[:, 1:3]  # make sure boxes aren't degenerate
        return rois

    @pytest.mark.parametrize("aligned", (True, False))
    @pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 10), (0.1, 50)))
    @pytest.mark.parametrize("qdtype", (torch.qint8, torch.quint8, torch.qint32))
    @pytest.mark.opcheck_only_one()
    def test_qroialign(self, aligned, scale, zero_point, qdtype):
        """Make sure quantized version of RoIAlign is close to float version"""
        pool_size = 5
        img_size = 10
        n_channels = 2
        num_imgs = 1
        dtype = torch.float

        x = torch.randint(50, 100, size=(num_imgs, n_channels, img_size, img_size)).to(dtype)
        qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=qdtype)

        rois = self._make_rois(img_size, num_imgs, dtype)
        qrois = torch.quantize_per_tensor(rois, scale=scale, zero_point=zero_point, dtype=qdtype)

        x, rois = qx.dequantize(), qrois.dequantize()  # we want to pass the same inputs

        y = ops.roi_align(
            x,
            rois,
            output_size=pool_size,
            spatial_scale=1,
            sampling_ratio=-1,
            aligned=aligned,
        )
        qy = ops.roi_align(
            qx,
            qrois,
            output_size=pool_size,
            spatial_scale=1,
            sampling_ratio=-1,
            aligned=aligned,
        )

        # The output qy is itself a quantized tensor and there might have been a loss of info when it was
        # quantized. For a fair comparison we need to quantize y as well
        quantized_float_y = torch.quantize_per_tensor(y, scale=scale, zero_point=zero_point, dtype=qdtype)

        try:
            # Ideally, we would assert this, which passes with (scale, zero) == (1, 0)
            assert (qy == quantized_float_y).all()
        except AssertionError:
            # But because the computations aren't exactly the same between the 2 RoIAlign procedures, some
            # rounding error may lead to a difference of 2 in the output.
            # For example with (scale, zero) = (2, 10), 45.00000... will be quantized to 44
            # but 45.00000001 will be rounded to 46. We make sure below that:
            # - such discrepancies between qy and quantized_float_y are very rare (less than 5%)
            # - any difference between qy and quantized_float_y is == scale
            diff_idx = torch.where(qy != quantized_float_y)
            num_diff = diff_idx[0].numel()
            assert num_diff / qy.numel() < 0.05

            abs_diff = torch.abs(qy[diff_idx].dequantize() - quantized_float_y[diff_idx].dequantize())
            t_scale = torch.full_like(abs_diff, fill_value=scale)
            torch.testing.assert_close(abs_diff, t_scale, rtol=1e-5, atol=1e-5)

    def test_qroi_align_multiple_images(self):
        """Quantized ``roi_align`` must raise when the batch holds more than one image."""
        dtype = torch.float
        x = torch.randint(50, 100, size=(2, 3, 10, 10)).to(dtype)
        qx = torch.quantize_per_tensor(x, scale=1, zero_point=0, dtype=torch.qint8)
        rois = self._make_rois(img_size=10, num_imgs=2, dtype=dtype, num_rois=10)
        qrois = torch.quantize_per_tensor(rois, scale=1, zero_point=0, dtype=torch.qint8)
        with pytest.raises(RuntimeError, match="Only one image per batch is allowed"):
            ops.roi_align(qx, qrois, output_size=5)

    def test_jit_boxes_list(self):
        """Scripted ``ops.RoIAlign`` must accept boxes given as a list of tensors."""
        model = PoolWrapper(ops.RoIAlign(output_size=[3, 3], spatial_scale=1.0, sampling_ratio=-1))
        self._helper_jit_boxes_list(model)

612

613
class TestPSRoIAlign(RoIOpTester):
    """Runs the shared RoI op tests against ``ops.PSRoIAlign`` / ``ops.ps_roi_align``."""

    mps_backward_atol = 5e-2

    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        """Apply ``ops.PSRoIAlign``; extra kwargs are accepted but unused."""
        return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois)

    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, wrap=False):
        """Build an ``ops.PSRoIAlign`` module, optionally wrapped in ``RoIOpTesterModuleWrapper``."""
        obj = ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)
        return RoIOpTesterModuleWrapper(obj) if wrap else obj

    def get_script_fn(self, rois, pool_size):
        """Return a function of ``x`` that calls the scripted ``ops.ps_roi_align``."""
        scripted = torch.jit.script(ops.ps_roi_align)
        return lambda x: scripted(x, rois, pool_size)

    def expected_fn(
        self, in_data, rois, pool_h, pool_w, device=None, spatial_scale=1, sampling_ratio=-1, dtype=torch.float64
    ):
        """Reference implementation: averaged bilinear samples, one input channel per (output channel, bin).

        ``device`` keeps its original position but now defaults to ``None``
        (meaning CPU), which the body already handled.

        Returns a tensor of shape
        ``(len(rois), n_input_channels / (pool_h * pool_w), pool_h, pool_w)``.
        """
        if device is None:
            device = torch.device("cpu")
        n_input_channels = in_data.size(1)
        assert n_input_channels % (pool_h * pool_w) == 0, "input channels must be divisible by ph * pw"
        n_output_channels = int(n_input_channels / (pool_h * pool_w))
        out_data = torch.zeros(rois.size(0), n_output_channels, pool_h, pool_w, dtype=dtype, device=device)

        for r, roi in enumerate(rois):
            batch_idx = int(roi[0])
            # Scaled box coordinates are always shifted by half a pixel here.
            j_begin, i_begin, j_end, i_end = (coord.item() * spatial_scale - 0.5 for coord in roi[1:])

            roi_h = i_end - i_begin
            roi_w = j_end - j_begin
            bin_h = roi_h / pool_h
            bin_w = roi_w / pool_w

            for i in range(0, pool_h):
                start_h = i_begin + i * bin_h
                # A non-positive sampling_ratio means: derive the grid size from the bin size.
                grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h))
                for j in range(0, pool_w):
                    start_w = j_begin + j * bin_w
                    grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))
                    for c_out in range(0, n_output_channels):
                        # Each (output channel, bin) pair reads its own input channel.
                        c_in = c_out * (pool_h * pool_w) + pool_w * i + j

                        val = 0
                        for iy in range(0, grid_h):
                            y = start_h + (iy + 0.5) * bin_h / grid_h
                            for ix in range(0, grid_w):
                                x = start_w + (ix + 0.5) * bin_w / grid_w
                                val += bilinear_interpolate(in_data[batch_idx, c_in, :, :], y, x, snap_border=True)
                        val /= grid_h * grid_w

                        out_data[r, c_out, i, j] = val
        return out_data

    def test_boxes_shape(self):
        """Badly shaped boxes must be rejected by ``ops.ps_roi_align``."""
        self._helper_boxes_shape(ops.ps_roi_align)

669

670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
@pytest.mark.parametrize(
    "op",
    (
        torch.ops.torchvision.roi_pool,
        torch.ops.torchvision.ps_roi_pool,
        torch.ops.torchvision.roi_align,
        torch.ops.torchvision.ps_roi_align,
    ),
)
@pytest.mark.parametrize("dtype", (torch.float16, torch.float32, torch.float64))
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("requires_grad", (True, False))
def test_roi_opcheck(op, dtype, device, requires_grad):
    """Call ``optests.opcheck`` by hand on each of the RoI ops."""
    # opcheck() is invoked manually here rather than through
    # opcheck.generate_opcheck_tests() (as is done e.g. for nms): pytest and
    # generate_opcheck_tests() don't interact very well when it comes to
    # skipping tests, and these ops need to skip the MPS tests because dynamic
    # shapes aren't supported on MPS yet.
    pool_size = 5
    num_channels = 2 * (pool_size**2)
    x = torch.rand(2, num_channels, 10, 10, dtype=dtype, device=device)
    rois = torch.tensor(
        [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]],
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
    )

    op_kwargs = {
        "rois": rois,
        "spatial_scale": 1,
        "pooled_height": pool_size,
        "pooled_width": pool_size,
    }
    # Only the align variants take a sampling ratio; only roi_align takes `aligned`.
    align_ops = (torch.ops.torchvision.roi_align, torch.ops.torchvision.ps_roi_align)
    if op in align_ops:
        op_kwargs["sampling_ratio"] = -1
    if op is torch.ops.torchvision.roi_align:
        op_kwargs["aligned"] = True

    optests.opcheck(op, args=(x,), kwargs=op_kwargs)


707
class TestMultiScaleRoIAlign:
    """Tests for ``ops.poolers.MultiScaleRoIAlign``."""

    def make_obj(self, fmap_names=None, output_size=(7, 7), sampling_ratio=2, wrap=False):
        """Build a ``MultiScaleRoIAlign``, optionally wrapped in ``MultiScaleRoIAlignModuleWrapper``.

        ``fmap_names`` defaults to ``["0"]`` when not given.
        """
        if fmap_names is None:
            fmap_names = ["0"]
        obj = ops.poolers.MultiScaleRoIAlign(fmap_names, output_size, sampling_ratio)
        return MultiScaleRoIAlignModuleWrapper(obj) if wrap else obj

    def test_msroialign_repr(self):
        """The ``repr`` must list the feature-map names, output size and sampling ratio."""
        fmap_names = ["0"]
        output_size = (7, 7)
        sampling_ratio = 2
        # Pass mock feature map names
        t = self.make_obj(fmap_names, output_size, sampling_ratio, wrap=False)

        # Check integrity of object __repr__ attribute
        expected_string = (
            f"MultiScaleRoIAlign(featmap_names={fmap_names}, output_size={output_size}, "
            f"sampling_ratio={sampling_ratio})"
        )
        assert repr(t) == expected_string

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_is_leaf_node(self, device):
        """Check the graph node names of the wrapped op: one node per input plus one more."""
        op_obj = self.make_obj(wrap=True).to(device=device)
        graph_node_names = get_graph_node_names(op_obj)

        assert len(graph_node_names) == 2
        assert len(graph_node_names[0]) == len(graph_node_names[1])
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs

737

738
739
class TestNMS:
    def _reference_nms(self, boxes, scores, iou_threshold):
        """
        Greedy NMS written with plain tensor ops, used as ground truth for ``ops.nms``.

        Args:
            boxes: boxes in corner-form
            scores: probabilities
            iou_threshold: intersection over union threshold
        Returns:
             picked: a tensor with the indexes of the kept boxes, highest score first
        """
        kept = []
        order = scores.sort(descending=True)[1]
        while len(order) > 0:
            best = order[0]
            kept.append(best.item())
            if len(order) == 1:
                break
            candidates = order[1:]
            overlaps = ops.box_iou(boxes[candidates, :], boxes[best, :].unsqueeze(0)).squeeze(1)
            # Only candidates that do not overlap the selected box too much survive.
            order = candidates[overlaps <= iou_threshold]

        return torch.as_tensor(kept)

763
764
765
766
767
    def _create_tensors_with_iou(self, N, iou_thresh):
        """Return ``N`` random boxes and scores where the last box overlaps the first by design.

        The last box is a copy of the first one, widened along x so that their IoU
        lands just above ``iou_thresh``.
        """
        # force last box to have a pre-defined iou with the first box
        # let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1],
        # then, in order to satisfy ops.iou(b0, b1) == iou_thresh,
        # we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh
        boxes = torch.rand(N, 4) * 100
        boxes[:, 2:] += boxes[:, :2]
        boxes[-1, :] = boxes[0, :]
        x0, _, x1, _ = boxes[-1].tolist()

        # Adjust the threshold upward a bit with the intent of creating
        # at least one box that exceeds (barely) the threshold and so
        # should be suppressed.
        target_iou = iou_thresh + 1e-5
        boxes[-1, 2] += (x1 - x0) * (1 - target_iou) / target_iou

        scores = torch.rand(N)
        return boxes, scores

780
    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.opcheck_only_one()
    def test_nms_ref(self, iou, seed):
        """``ops.nms`` on CPU must keep the same boxes as ``_reference_nms``."""
        torch.random.manual_seed(seed)
        boxes, scores = self._create_tensors_with_iou(1000, iou)

        expected_keep = self._reference_nms(boxes, scores, iou)
        actual_keep = ops.nms(boxes, scores, iou)

        torch.testing.assert_close(
            actual_keep,
            expected_keep,
            msg="NMS incompatible between CPU and reference implementation for IoU={}".format(iou),
        )
790
791
792
793
794
795
796
797
798
799
800

    def test_nms_input_errors(self):
        """``ops.nms`` must raise ``RuntimeError`` for malformed box/score shapes."""
        # (boxes shape, scores shape) pairs that are each expected to be rejected.
        bad_shapes = (
            ((4,), (3,)),
            ((3, 5), (3,)),
            ((3, 4), (3, 2)),
            ((3, 4), (4,)),
        )
        for boxes_shape, scores_shape in bad_shapes:
            with pytest.raises(RuntimeError):
                ops.nms(torch.rand(*boxes_shape), torch.rand(*scores_shape), 0.5)

801
    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 50), (3, 10)))
    @pytest.mark.opcheck_only_one()
    def test_qnms(self, iou, scale, zero_point):
        """NMS on quantized inputs must match NMS on the dequantized copies of those inputs."""
        # Note: we compare qnms vs nms instead of qnms vs reference implementation.
        # This is because with the int conversion, the trick used in _create_tensors_with_iou
        # doesn't really work (in fact, nms vs reference implem will also fail with ints)
        boxes, scores = self._create_tensors_with_iou(1000, iou)
        scores *= 100  # otherwise most scores would be 0 or 1 after int conversion

        quant_args = dict(scale=scale, zero_point=zero_point, dtype=torch.quint8)
        qboxes = torch.quantize_per_tensor(boxes, **quant_args)
        qscores = torch.quantize_per_tensor(scores, **quant_args)

        # The float run sees exactly the values the quantized tensors represent.
        keep = ops.nms(qboxes.dequantize(), qscores.dequantize(), iou)
        qkeep = ops.nms(qboxes, qscores, iou)

        torch.testing.assert_close(qkeep, keep, msg="NMS and QNMS give different results for IoU={}".format(iou))
822

823
824
825
826
827
828
829
    @pytest.mark.parametrize(
        "device",
        (
            pytest.param("cuda", marks=pytest.mark.needs_cuda),
            pytest.param("mps", marks=pytest.mark.needs_mps),
        ),
    )
    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.opcheck_only_one()
    def test_nms_gpu(self, iou, device, dtype=torch.float64):
        """``ops.nms`` on an accelerator must agree with the CPU result.

        NOTE(review): ``dtype`` only selects the tolerance below; the boxes and scores
        are used exactly as ``_create_tensors_with_iou`` returns them and are never
        cast to ``dtype``.
        """
        if device == "mps":
            dtype = torch.float32
        tol = 1e-3 if dtype is torch.half else 1e-5

        boxes, scores = self._create_tensors_with_iou(1000, iou)
        keep_cpu = ops.nms(boxes, scores, iou)
        keep_dev = ops.nms(boxes.to(device), scores.to(device), iou).cpu()

        matches = torch.allclose(keep_cpu, keep_dev)
        if not matches:
            # if the indices are not the same, ensure that it's because the scores
            # are duplicate
            matches = torch.allclose(scores[keep_cpu], scores[keep_dev], rtol=tol, atol=tol)
        assert matches, "NMS incompatible between CPU and CUDA for IoU={}".format(iou)

    @needs_cuda
    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.parametrize("dtype", (torch.float, torch.half))
    @pytest.mark.opcheck_only_one()
    def test_autocast(self, iou, dtype):
        """Re-run the CPU-vs-CUDA NMS comparison inside a CUDA autocast region."""
        autocast_ctx = torch.cuda.amp.autocast()
        with autocast_ctx:
            self.test_nms_gpu(iou=iou, device="cuda", dtype=dtype)

856
857
858
859
860
861
862
863
864
    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.parametrize("dtype", (torch.float, torch.bfloat16))
    def test_autocast_cpu(self, iou, dtype):
        """Under CPU autocast, NMS on ``dtype`` inputs must match NMS on their float copies."""
        boxes, scores = self._create_tensors_with_iou(1000, iou)
        with torch.cpu.amp.autocast():
            # Reference: round the inputs to ``dtype`` first, then convert them back to float.
            ref_boxes = boxes.to(dtype).float()
            ref_scores = scores.to(dtype).float()
            keep_ref_float = ops.nms(ref_boxes, ref_scores, iou)
            keep_dtype = ops.nms(boxes.to(dtype), scores.to(dtype), iou)
        torch.testing.assert_close(keep_ref_float, keep_dtype)

865
866
867
868
869
870
871
    @pytest.mark.parametrize(
        "device",
        (
            pytest.param("cuda", marks=pytest.mark.needs_cuda),
            pytest.param("mps", marks=pytest.mark.needs_mps),
        ),
    )
    @pytest.mark.opcheck_only_one()
    def test_nms_float16(self, device):
        """NMS must keep the same indices for float16 inputs as for the original tensors."""
        box_rows = [
            [285.3538, 185.5758, 1193.5110, 851.4551],
            [285.1472, 188.7374, 1192.4984, 851.0669],
            [279.2440, 197.9812, 1189.4746, 849.2019],
        ]
        boxes = torch.tensor(box_rows).to(device)
        scores = torch.tensor([0.6370, 0.7569, 0.3966]).to(device)

        iou_thres = 0.2
        keep_full = ops.nms(boxes, scores, iou_thres)
        keep_half = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
        assert_equal(keep_full, keep_half)
887

888
    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.opcheck_only_one()
    def test_batched_nms_implementations(self, seed):
        """Make sure that both implementations of batched_nms yield identical results"""
        torch.random.manual_seed(seed)

        num_boxes = 1000
        iou_threshold = 0.9

        # The second corner is shifted by 10, so it always lies beyond the first corner.
        top_left = torch.rand(num_boxes, 2)
        bottom_right = torch.rand(num_boxes, 2) + 10
        boxes = torch.cat((top_left, bottom_right), dim=1)
        assert boxes[:, 0].max() < boxes[:, 2].min()  # x1 < x2
        assert boxes[:, 1].max() < boxes[:, 3].min()  # y1 < y2

        scores = torch.rand(num_boxes)
        idxs = torch.randint(0, 4, size=(num_boxes,))
        keep_vanilla = ops.boxes._batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
        keep_trick = ops.boxes._batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)

        torch.testing.assert_close(
            keep_vanilla, keep_trick, msg="The vanilla and the trick implementation yield different nms outputs."
        )

        # Also make sure an empty tensor is returned if boxes is empty
        no_boxes = torch.empty((0,), dtype=torch.int64)
        torch.testing.assert_close(no_boxes, ops.batched_nms(no_boxes, None, None, None))
913

914

915
916
917
918
919
920
921
922
923
# Ask the optests helper to generate opcheck tests from TestNMS: the utilities named in
# OPTESTS are requested for the "torchvision" namespace, and the failures dictionary is
# read from optests_failures_dict.json next to this file.
# NOTE(review): how tests are selected and attached (presumably driven by the
# `opcheck_only_one` marks above) is defined in torch.testing._internal.optests - confirm there.
optests.generate_opcheck_tests(
    testcase=TestNMS,
    namespaces=["torchvision"],
    failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
    additional_decorators=[],
    test_utils=OPTESTS,
)


924
925
926
class TestDeformConv:
    dtype = torch.float64

927
    def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
        """Reference deformable convolution computed with explicit Python loops.

        Args:
            x: input indexed as ``[batch, channel, row, col]``.
            weight: kernel indexed as ``[out_channel, in_channel_within_group, di, dj]``.
            offset: per-location sampling offsets; channels ``2 * k`` and ``2 * k + 1`` hold
                the row and column offset of kernel tap ``k``.
            mask: optional per-tap modulation; ``None`` means every tap is weighted by 1.
            bias: one value per output channel, added at the end.
            stride, padding, dilation: ints or pairs, expanded with ``_pair``.

        Returns:
            A tensor indexed ``[batch, out_channel, out_row, out_col]`` on ``x``'s device and dtype.
        """
        stride_h, stride_w = _pair(stride)
        pad_h, pad_w = _pair(padding)
        dil_h, dil_w = _pair(dilation)
        weight_h, weight_w = weight.shape[-2:]
        kernel_area = weight_h * weight_w

        n_batches, n_in_channels, in_h, in_w = x.shape
        n_out_channels = weight.shape[0]
        in_c_per_weight_grp = weight.shape[1]

        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1

        n_offset_grps = offset.shape[1] // (2 * kernel_area)
        in_c_per_offset_grp = n_in_channels // n_offset_grps

        n_weight_grps = n_in_channels // in_c_per_weight_grp
        out_c_per_weight_grp = n_out_channels // n_weight_grps

        out = torch.zeros(n_batches, n_out_channels, out_h, out_w, device=x.device, dtype=x.dtype)
        for b in range(n_batches):
            for c_out in range(n_out_channels):
                # First input channel of the weight group this output channel belongs to.
                first_c_in = (c_out // out_c_per_weight_grp) * in_c_per_weight_grp
                for i in range(out_h):
                    for j in range(out_w):
                        for di in range(weight_h):
                            for dj in range(weight_w):
                                for c in range(in_c_per_weight_grp):
                                    c_in = first_c_in + c

                                    offset_grp = c_in // in_c_per_offset_grp
                                    mask_idx = offset_grp * kernel_area + di * weight_w + dj
                                    offset_idx = 2 * mask_idx

                                    # Regular sampling position shifted by the learned offset.
                                    pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j]
                                    pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j]

                                    mask_value = 1.0 if mask is None else mask[b, mask_idx, i, j]

                                    out[b, c_out, i, j] += (
                                        mask_value
                                        * weight[c_out, c, di, dj]
                                        * bilinear_interpolate(x[b, c_in, :, :], pi, pj)
                                    )
        out += bias.view(1, n_out_channels, 1, 1)
        return out

976
    def get_fn_args(self, device, contiguous, batch_sz, dtype):
        """Create random inputs for a fixed deformable-convolution configuration.

        The configuration (6 input channels, 2 output channels, 2 weight groups,
        3 offset groups, 3x2 kernel, 5x4 input) matches the defaults of ``make_obj``.

        This method is intentionally not wrapped in ``lru_cache``: a cache on an instance
        method includes ``self`` in the key, and pytest instantiates the test class anew
        for every test, so such a cache would never hit and would only keep every
        instance and its tensors alive.

        Args:
            device: device on which the tensors are created.
            contiguous: when false, the tensors are returned with the same shapes and
                values but a permuted (non-contiguous) memory layout.
            batch_sz: batch size of ``x``, ``offset`` and ``mask``.
            dtype: dtype of all tensors.

        Returns:
            ``(x, weight, offset, mask, bias, stride, pad, dilation)``; all five tensors
            are created with ``requires_grad=True``.
        """
        n_in_channels = 6
        n_out_channels = 2
        n_weight_grps = 2
        n_offset_grps = 3

        stride = (2, 1)
        pad = (1, 0)
        dilation = (2, 1)

        stride_h, stride_w = stride
        pad_h, pad_w = pad
        dil_h, dil_w = dilation
        weight_h, weight_w = (3, 2)
        in_h, in_w = (5, 4)

        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1

        tensor_kwargs = dict(device=device, dtype=dtype, requires_grad=True)
        kernel_area = weight_h * weight_w

        x = torch.rand(batch_sz, n_in_channels, in_h, in_w, **tensor_kwargs)
        # Two offset channels (row, column) per kernel tap and offset group.
        offset = torch.randn(batch_sz, n_offset_grps * 2 * kernel_area, out_h, out_w, **tensor_kwargs)
        # One modulation channel per kernel tap and offset group.
        mask = torch.randn(batch_sz, n_offset_grps * kernel_area, out_h, out_w, **tensor_kwargs)
        weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w, **tensor_kwargs)
        bias = torch.randn(n_out_channels, **tensor_kwargs)

        if not contiguous:
            # Each permute/contiguous/permute round trip restores the original shape
            # while leaving the data in a different memory layout.
            x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
            offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)

        return x, weight, offset, mask, bias, stride, pad, dilation
1031

1032
1033
1034
1035
1036
1037
    def make_obj(self, in_channels=6, out_channels=2, kernel_size=(3, 2), groups=2, wrap=False):
        """Build a ``DeformConv2d`` with fixed stride/padding/dilation, optionally wrapped.

        With ``wrap=True`` the layer is returned inside ``DeformConvModuleWrapper``.
        """
        conv = ops.DeformConv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=(2, 1),
            padding=(1, 0),
            dilation=(2, 1),
            groups=groups,
        )
        if wrap:
            return DeformConvModuleWrapper(conv)
        return conv

1038
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_is_leaf_node(self, device):
        """The wrapped op must show up as a single graph node next to the wrapper's inputs."""
        wrapped = self.make_obj(wrap=True).to(device=device)
        node_names = get_graph_node_names(wrapped)

        assert len(node_names) == 2
        first, second = node_names[0], node_names[1]
        assert len(first) == len(second)
        # One node per wrapper input plus one for the op itself.
        assert len(first) == 1 + wrapped.n_inputs

1047
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("batch_sz", (0, 33))
    @pytest.mark.opcheck_only_one()
    def test_forward(self, device, contiguous, batch_sz, dtype=None):
        """``DeformConv2d`` forward must match ``expected_fn``, with and without a mask."""
        dtype = dtype or self.dtype
        x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
        in_channels = 6
        out_channels = 2
        kernel_size = (3, 2)
        groups = 2
        tol = 2e-3 if dtype is torch.half else 1e-5

        layer = self.make_obj(in_channels, out_channels, kernel_size, groups, wrap=False)
        layer = layer.to(device=x.device, dtype=dtype)

        # The reference is evaluated with the layer's own parameters.
        weight = layer.weight.data
        bias = layer.bias.data

        def check_against_reference(res, ref_mask):
            expected = self.expected_fn(
                x, weight, offset, ref_mask, bias, stride=stride, padding=padding, dilation=dilation
            )
            torch.testing.assert_close(
                res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}"
            )

        check_against_reference(layer(x, offset, mask), mask)

        # no modulation test
        check_against_reference(layer(x, offset), None)
1080

1081
1082
1083
1084
1085
    def test_wrong_sizes(self):
        """Offsets or masks with too few channels must be rejected with a ``RuntimeError``."""
        x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(
            "cpu", contiguous=True, batch_sz=10, dtype=self.dtype
        )
        in_channels, out_channels = 6, 2
        kernel_size = (3, 2)
        groups = 2
        layer = ops.DeformConv2d(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups
        )

        # Keep only the first two channels to obtain an invalid offset.
        with pytest.raises(RuntimeError, match="the shape of the offset"):
            bad_offset = torch.rand_like(offset[:, :2])
            layer(x, bad_offset)

        # Same trick for the mask, paired with a valid offset.
        with pytest.raises(RuntimeError, match=r"mask.shape\[1\] is not valid"):
            bad_mask = torch.rand_like(mask[:, :2])
            layer(x, offset, bad_mask)
1099

1100
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("batch_sz", (0, 33))
    @pytest.mark.opcheck_only_one()
    def test_backward(self, device, contiguous, batch_sz):
        """Run ``gradcheck`` on ``ops.deform_conv2d``: eager and scripted, with and without mask."""
        x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(
            device, contiguous, batch_sz, self.dtype
        )
        conv_kwargs = dict(stride=stride, padding=padding, dilation=dilation)
        check_kwargs = dict(nondet_tol=1e-5, fast_mode=True)
        masked_inputs = (x, offset, mask, weight, bias)
        unmasked_inputs = (x, offset, weight, bias)

        def func(x_, offset_, mask_, weight_, bias_):
            return ops.deform_conv2d(x_, offset_, weight_, bias_, mask=mask_, **conv_kwargs)

        gradcheck(func, masked_inputs, **check_kwargs)

        def func_no_mask(x_, offset_, weight_, bias_):
            return ops.deform_conv2d(x_, offset_, weight_, bias_, mask=None, **conv_kwargs)

        gradcheck(func_no_mask, unmasked_inputs, **check_kwargs)

        # The scripted variants take stride/padding/dilation as explicit, typed arguments.
        @torch.jit.script
        def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_):
            # type:(Tensor, Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
            return ops.deform_conv2d(
                x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_, mask=mask_
            )

        def scripted_masked(z, off, msk, wei, bi):
            return script_func(z, off, msk, wei, bi, stride, padding, dilation)

        gradcheck(scripted_masked, masked_inputs, **check_kwargs)

        @torch.jit.script
        def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
            # type:(Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
            return ops.deform_conv2d(
                x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_, mask=None
            )

        def scripted_unmasked(z, off, wei, bi):
            return script_func_no_mask(z, off, wei, bi, stride, padding, dilation)

        gradcheck(scripted_unmasked, unmasked_inputs, **check_kwargs)
1150

1151
    @needs_cuda
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.opcheck_only_one()
    def test_compare_cpu_cuda_grads(self, contiguous):
        """Weight gradients of ``ops.deform_conv2d`` must agree between CPU and CUDA."""
        # Test from https://github.com/pytorch/vision/issues/2598
        # Run on CUDA only

        init_weight = torch.randn(9, 9, 3, 3, requires_grad=True)
        img = torch.randn(8, 9, 1000, 110)
        offset = torch.rand(8, 2 * 3 * 3, 1000, 110)
        mask = torch.rand(8, 3 * 3, 1000, 110)

        if not contiguous:
            img = img.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
            offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            weight = init_weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
        else:
            weight = init_weight

        # compare grads computed on CUDA with grads computed on CPU
        grads = {}
        for d in ["cpu", "cuda"]:
            # Both runs backpropagate into the same CPU leaf. Reset its grad so the runs
            # do not accumulate into each other, and snapshot the result: keeping a bare
            # reference to ``init_weight.grad`` would alias the tensor the next backward
            # updates, which would make the comparison below trivially true.
            init_weight.grad = None
            out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d))
            out.mean().backward()
            assert init_weight.grad is not None
            grads[d] = init_weight.grad.detach().clone().to("cpu")

        torch.testing.assert_close(grads["cpu"], grads["cuda"])

    @needs_cuda
    @pytest.mark.parametrize("batch_sz", (0, 33))
    @pytest.mark.parametrize("dtype", (torch.float, torch.half))
    @pytest.mark.opcheck_only_one()
    def test_autocast(self, batch_sz, dtype):
        """Re-run the forward comparison on CUDA inside an autocast region."""
        cuda = torch.device("cuda")
        with torch.cuda.amp.autocast():
            self.test_forward(cuda, contiguous=False, batch_sz=batch_sz, dtype=dtype)

1193
1194
1195
1196
    def test_forward_scriptability(self):
        """``DeformConv2d`` must be compilable with ``torch.jit.script``."""
        # Non-regression test for https://github.com/pytorch/vision/issues/4078
        layer = ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3)
        torch.jit.script(layer)

1197

1198
1199
1200
1201
1202
1203
1204
1205
1206
# Ask the optests helper to generate opcheck tests from TestDeformConv: the utilities named
# in OPTESTS are requested for the "torchvision" namespace, and the failures dictionary is
# read from optests_failures_dict.json next to this file.
# NOTE(review): how tests are selected and attached (presumably driven by the
# `opcheck_only_one` marks above) is defined in torch.testing._internal.optests - confirm there.
optests.generate_opcheck_tests(
    testcase=TestDeformConv,
    namespaces=["torchvision"],
    failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
    additional_decorators=[],
    test_utils=OPTESTS,
)


1207
class TestFrozenBNT:
    """Tests for ``ops.misc.FrozenBatchNorm2d``."""

    def test_frozenbatchnorm2d_repr(self):
        """``repr`` must show the number of features and ``eps``."""
        num_features = 32
        eps = 1e-5
        layer = ops.misc.FrozenBatchNorm2d(num_features, eps=eps)

        # Check integrity of object __repr__ attribute
        expected_string = f"FrozenBatchNorm2d({num_features}, eps={eps})"
        assert repr(layer) == expected_string

    @pytest.mark.parametrize("seed", range(10))
    def test_frozenbatchnorm2d_eps(self, seed):
        """``FrozenBatchNorm2d`` must match an eval-mode ``BatchNorm2d`` sharing its state."""
        torch.random.manual_seed(seed)

        sample_size = (4, 32, 28, 28)
        n_channels = sample_size[1]
        x = torch.rand(sample_size)
        state_dict = {
            "weight": torch.rand(n_channels),
            "bias": torch.rand(n_channels),
            "running_mean": torch.rand(n_channels),
            "running_var": torch.rand(n_channels),
            "num_batches_tracked": torch.tensor(100),
        }

        # First pass: check that default eps is equal to the one of BN.
        # Second pass: check computation for eps > 0.
        for eps_kwargs in ({}, {"eps": 1e-5}):
            fbn = ops.misc.FrozenBatchNorm2d(n_channels, **eps_kwargs)
            fbn.load_state_dict(state_dict, strict=False)
            bn = torch.nn.BatchNorm2d(n_channels, **eps_kwargs).eval()
            bn.load_state_dict(state_dict)
            # Difference is expected to fall in an acceptable range
            torch.testing.assert_close(fbn(x), bn(x), rtol=1e-5, atol=1e-6)
1244

1245

Aditya Oke's avatar
Aditya Oke committed
1246
class TestBoxConversionToRoi:
    """Tests for the box-argument helpers in ``ops._utils``."""

    # Deliberately a plain function without ``self``: it is called while the class body
    # is being executed, to feed the ``parametrize`` decorators below.
    def _get_box_sequences():
        """Return the same two boxes as an RoI tensor, a list of tensors and a tuple of tensors.

        In the tensor form the first column is the batch index; in the list/tuple form
        the position of each tensor plays that role.
        """
        # Define here the argument type of `boxes` supported by region pooling operations
        box_tensor = torch.tensor([[0, 0, 0, 100, 100], [1, 0, 0, 100, 100]], dtype=torch.float)
        box_list = [
            torch.tensor([[0, 0, 100, 100]], dtype=torch.float),
            torch.tensor([[0, 0, 100, 100]], dtype=torch.float),
        ]
        box_tuple = tuple(box_list)
        return box_tensor, box_list, box_tuple

    @pytest.mark.parametrize("box_sequence", _get_box_sequences())
    def test_check_roi_boxes_shape(self, box_sequence):
        # Ensure common sequences of tensors are supported
        ops._utils.check_roi_boxes_shape(box_sequence)

    @pytest.mark.parametrize("box_sequence", _get_box_sequences())
    def test_convert_boxes_to_roi_format(self, box_sequence):
        """Every supported box container must describe the same RoIs.

        The previous version reset its reference to ``None`` right before testing it
        for ``None``, so its only assertion could never run.
        """
        # Ensure common sequences of tensors yield the same result
        ref_tensor = TestBoxConversionToRoi._get_box_sequences()[0]
        if isinstance(box_sequence, torch.Tensor):
            # NOTE(review): a Tensor is taken to be in RoI format already, and
            # convert_boxes_to_roi_format to accept only sequences of tensors - confirm.
            rois = box_sequence
        else:
            rois = ops._utils.convert_boxes_to_roi_format(box_sequence)
        assert_equal(ref_tensor, rois)
1270
1271


Aditya Oke's avatar
Aditya Oke committed
1272
class TestBoxConvert:
    """Tests for ``ops.box_convert`` across the ``xyxy``, ``xywh`` and ``cxcywh`` formats."""

    def test_bbox_same(self):
        """Converting to the input's own format must leave the boxes unchanged."""
        box_tensor = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
        )
        exp_xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)

        assert exp_xyxy.size() == torch.Size([4, 4])
        for fmt in ("xyxy", "xywh", "cxcywh"):
            assert_equal(ops.box_convert(box_tensor, in_fmt=fmt, out_fmt=fmt), exp_xyxy)

    def test_bbox_xyxy_xywh(self):
        """Round trip xyxy -> xywh -> xyxy."""
        # Simple test convert boxes to xywh and back. Make sure they are same.
        # box_tensor is in x1 y1 x2 y2 format.
        box_tensor = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
        )
        exp_xywh = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float)
        assert exp_xywh.size() == torch.Size([4, 4])

        converted = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
        assert_equal(converted, exp_xywh)

        # Reverse conversion
        restored = ops.box_convert(converted, in_fmt="xywh", out_fmt="xyxy")
        assert_equal(restored, box_tensor)

    def test_bbox_xyxy_cxcywh(self):
        """Round trip xyxy -> cxcywh -> xyxy."""
        # Simple test convert boxes to cxcywh and back. Make sure they are same.
        # box_tensor is in x1 y1 x2 y2 format.
        box_tensor = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
        )
        exp_cxcywh = torch.tensor(
            [[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float
        )
        assert exp_cxcywh.size() == torch.Size([4, 4])

        converted = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
        assert_equal(converted, exp_cxcywh)

        # Reverse conversion
        restored = ops.box_convert(converted, in_fmt="cxcywh", out_fmt="xyxy")
        assert_equal(restored, box_tensor)

    def test_bbox_xywh_cxcywh(self):
        """Round trip xywh -> cxcywh -> xywh."""
        box_tensor = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float
        )
        exp_cxcywh = torch.tensor(
            [[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float
        )
        assert exp_cxcywh.size() == torch.Size([4, 4])

        converted = ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="cxcywh")
        assert_equal(converted, exp_cxcywh)

        # Reverse conversion
        restored = ops.box_convert(converted, in_fmt="cxcywh", out_fmt="xywh")
        assert_equal(restored, box_tensor)

    @pytest.mark.parametrize("inv_infmt", ["xwyh", "cxwyh"])
    @pytest.mark.parametrize("inv_outfmt", ["xwcx", "xhwcy"])
    def test_bbox_invalid(self, inv_infmt, inv_outfmt):
        """Unknown format names must raise ``ValueError``."""
        box_tensor = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float
        )

        with pytest.raises(ValueError):
            ops.box_convert(box_tensor, inv_infmt, inv_outfmt)

    def test_bbox_convert_jit(self):
        """The scripted ``box_convert`` must return the same boxes as the eager function."""
        box_tensor = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
        )

        scripted_fn = torch.jit.script(ops.box_convert)

        for out_fmt in ("xywh", "cxcywh"):
            eager_boxes = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt=out_fmt)
            scripted_boxes = scripted_fn(box_tensor, "xyxy", out_fmt)
            torch.testing.assert_close(scripted_boxes, eager_boxes)
1360
1361


Aditya Oke's avatar
Aditya Oke committed
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
class TestBoxArea:
    """Tests for ``ops.box_area``."""

    def area_check(self, box, expected, atol=1e-4):
        """Assert that ``ops.box_area(box)`` is within ``atol`` of ``expected`` (dtype is ignored)."""
        actual = ops.box_area(box)
        torch.testing.assert_close(actual, expected, rtol=0.0, check_dtype=False, atol=atol)

    @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64])
    def test_int_boxes(self, dtype):
        """Integer boxes of every width must give the areas 10000 and 0."""
        boxes = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype)
        self.area_check(boxes, torch.tensor([10000, 0], dtype=torch.int32))

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
    def test_float_boxes(self, dtype):
        """Areas of the module-level ``FLOAT_BOXES`` in single and double precision."""
        boxes = torch.tensor(FLOAT_BOXES, dtype=dtype)
        self.area_check(boxes, torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=dtype))

    def test_float16_box(self):
        """Half-precision boxes are checked with a looser absolute tolerance."""
        boxes = torch.tensor(
            [[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]], dtype=torch.float16
        )
        areas = torch.tensor([3.2170, 3.7108, 18.5071], dtype=torch.float16)
        self.area_check(boxes, areas, atol=0.01)

    def test_box_area_jit(self):
        """The scripted ``box_area`` must return the same areas as the eager function."""
        boxes = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)
        eager_area = ops.box_area(boxes)
        scripted_fn = torch.jit.script(ops.box_area)
        torch.testing.assert_close(scripted_fn(boxes), eager_area)
1393

Aditya Oke's avatar
Aditya Oke committed
1394

1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
# Shared box fixtures, one 4-number row per box.
# INT_BOXES2 is INT_BOXES without its last row.
INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300], [0, 0, 25, 25]]
INT_BOXES2 = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
# Read by TestBoxArea.test_float_boxes; same three boxes as TestNMS.test_nms_float16.
FLOAT_BOXES = [
    [285.3538, 185.5758, 1193.5110, 851.4551],
    [285.1472, 188.7374, 1192.4984, 851.0669],
    [279.2440, 197.9812, 1189.4746, 849.2019],
]


def gen_box(size, dtype=torch.float):
    """Return ``size`` random boxes as a ``(size, 4)`` tensor.

    The first two columns are drawn uniformly from [0, 1); the last two columns are
    the first two plus another uniform [0, 1) draw, so they are never smaller.
    """
    corner_shape = (size, 2)
    top_left = torch.rand(corner_shape, dtype=dtype)
    extent = torch.rand(corner_shape, dtype=dtype)
    return torch.cat((top_left, top_left + extent), dim=-1)


Aditya Oke's avatar
Aditya Oke committed
1410
1411
class TestIouBase:
    """Shared helpers for the pairwise box-IoU test classes.

    Every helper is a static method that receives the IoU function under test
    as ``target_fn``.
    """

    @staticmethod
    def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected):
        """Compare ``target_fn(box1, box2)`` with ``expected`` for every dtype in ``dtypes``.

        Args:
            target_fn: callable taking two box tensors and returning a tensor.
            actual_box1: raw (nested list) boxes passed as the first argument.
            actual_box2: raw (nested list) boxes passed as the second argument.
            dtypes: iterable of torch dtypes the raw boxes are converted to.
            atol: absolute tolerance of the comparison (``rtol`` is fixed to 0).
            expected: raw (nested list) expected output.

        Raises:
            AssertionError: if an output differs from ``expected`` by more than ``atol``.
        """
        # The expected values do not depend on the input dtype, so build them once.
        expected_box = torch.tensor(expected)
        for dtype in dtypes:
            # Build the tensors from the raw inputs on every iteration instead of
            # rebinding the arguments: otherwise each dtype after the first would be
            # converted from the previous dtype's (possibly lower-precision) tensor.
            box1 = torch.tensor(actual_box1, dtype=dtype)
            box2 = torch.tensor(actual_box2, dtype=dtype)
            out = target_fn(box1, box2)
            torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)

    @staticmethod
    def _run_jit_test(target_fn: Callable, actual_box: List):
        """Check that the TorchScript-compiled ``target_fn`` matches eager execution.

        ``actual_box`` is converted to a float tensor and compared against itself.
        """
        box_tensor = torch.tensor(actual_box, dtype=torch.float)
        expected = target_fn(box_tensor, box_tensor)
        scripted_fn = torch.jit.script(target_fn)
        scripted_out = scripted_fn(box_tensor, box_tensor)
        torch.testing.assert_close(scripted_out, expected)

    @staticmethod
    def _cartesian_product(boxes1, boxes2, target_fn: Callable):
        """Build the ``(N, M)`` result matrix by calling ``target_fn`` on one box pair at a time."""
        N = boxes1.size(0)
        M = boxes2.size(0)
        result = torch.zeros((N, M))
        for i in range(N):
            for j in range(M):
                # unsqueeze(0) turns each single box into a batch of one box.
                result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0))
        return result

    @staticmethod
    def _run_cartesian_test(target_fn: Callable):
        """Check that the batched ``target_fn`` equals the pair-by-pair evaluation on random boxes."""
        boxes1 = gen_box(5)
        boxes2 = gen_box(7)
        a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn)
        b = target_fn(boxes1, boxes2)
        torch.testing.assert_close(a, b)
1445

1446

Aditya Oke's avatar
Aditya Oke committed
1447
class TestBoxIou(TestIouBase):
    """Tests for ``ops.box_iou``."""

    # Expected result for INT_BOXES (rows) against INT_BOXES2 (columns).
    int_expected = [
        [1.0, 0.25, 0.0],
        [0.25, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.0625, 0.25, 0.0],
    ]
    # Expected result for FLOAT_BOXES against itself.
    float_expected = [
        [1.0, 0.9933, 0.9673],
        [0.9933, 1.0, 0.9737],
        [0.9673, 0.9737, 1.0],
    ]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            (INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        """Compare ``ops.box_iou`` with the precomputed values for each dtype group."""
        iou_fn = ops.box_iou
        self._run_test(iou_fn, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        """The scripted op must agree with the eager op."""
        iou_fn = ops.box_iou
        self._run_jit_test(iou_fn, INT_BOXES)

    def test_iou_cartesian(self):
        """The batched op must agree with pair-by-pair evaluation."""
        iou_fn = ops.box_iou
        self._run_cartesian_test(iou_fn)

1468

Aditya Oke's avatar
Aditya Oke committed
1469
class TestGeneralizedBoxIou(TestIouBase):
    """Tests for ``ops.generalized_box_iou``."""

    # Expected result for INT_BOXES (rows) against INT_BOXES2 (columns).
    int_expected = [
        [1.0, 0.25, -0.7778],
        [0.25, 1.0, -0.8611],
        [-0.7778, -0.8611, 1.0],
        [0.0625, 0.25, -0.8819],
    ]
    # Expected result for FLOAT_BOXES against itself.
    float_expected = [
        [1.0, 0.9933, 0.9673],
        [0.9933, 1.0, 0.9737],
        [0.9673, 0.9737, 1.0],
    ]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            (INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        """Compare ``ops.generalized_box_iou`` with the precomputed values for each dtype group."""
        iou_fn = ops.generalized_box_iou
        self._run_test(iou_fn, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        """The scripted op must agree with the eager op."""
        iou_fn = ops.generalized_box_iou
        self._run_jit_test(iou_fn, INT_BOXES)

    def test_iou_cartesian(self):
        """The batched op must agree with pair-by-pair evaluation."""
        iou_fn = ops.generalized_box_iou
        self._run_cartesian_test(iou_fn)

1490

Aditya Oke's avatar
Aditya Oke committed
1491
class TestDistanceBoxIoU(TestIouBase):
    """Tests for ``ops.distance_box_iou``."""

    # Expected result for INT_BOXES (rows) against INT_BOXES2 (columns).
    int_expected = [
        [1.0000, 0.1875, -0.4444],
        [0.1875, 1.0000, -0.5625],
        [-0.4444, -0.5625, 1.0000],
        [-0.0781, 0.1875, -0.6267],
    ]
    # Expected result for FLOAT_BOXES against itself.
    float_expected = [
        [1.0, 0.9933, 0.9673],
        [0.9933, 1.0, 0.9737],
        [0.9673, 0.9737, 1.0],
    ]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            (INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        """Compare ``ops.distance_box_iou`` with the precomputed values for each dtype group."""
        iou_fn = ops.distance_box_iou
        self._run_test(iou_fn, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        """The scripted op must agree with the eager op."""
        iou_fn = ops.distance_box_iou
        self._run_jit_test(iou_fn, INT_BOXES)

    def test_iou_cartesian(self):
        """The batched op must agree with pair-by-pair evaluation."""
        iou_fn = ops.distance_box_iou
        self._run_cartesian_test(iou_fn)

1517

Aditya Oke's avatar
Aditya Oke committed
1518
class TestCompleteBoxIou(TestIouBase):
    """Tests for ``ops.complete_box_iou``."""

    # Expected result for INT_BOXES (rows) against INT_BOXES2 (columns).
    int_expected = [
        [1.0000, 0.1875, -0.4444],
        [0.1875, 1.0000, -0.5625],
        [-0.4444, -0.5625, 1.0000],
        [-0.0781, 0.1875, -0.6267],
    ]
    # Expected result for FLOAT_BOXES against itself.
    float_expected = [
        [1.0, 0.9933, 0.9673],
        [0.9933, 1.0, 0.9737],
        [0.9673, 0.9737, 1.0],
    ]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            (INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            (FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        """Compare ``ops.complete_box_iou`` with the precomputed values for each dtype group."""
        iou_fn = ops.complete_box_iou
        self._run_test(iou_fn, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        """The scripted op must agree with the eager op."""
        iou_fn = ops.complete_box_iou
        self._run_jit_test(iou_fn, INT_BOXES)

    def test_iou_cartesian(self):
        """The batched op must agree with pair-by-pair evaluation."""
        iou_fn = ops.complete_box_iou
        self._run_cartesian_test(iou_fn)

1544

Aditya Oke's avatar
Aditya Oke committed
1545
1546
1547
1548
1549
def get_boxes(dtype, device):
    """Build the box fixtures shared by the IoU loss tests.

    Returns:
        ``(box1, box2, box3, box4, box1s, box2s)``: four single boxes of shape ``(4,)``
        followed by two stacked batches of shape ``(2, 4)``, all created with the
        requested ``dtype`` on ``device``.
    """

    def make(coords):
        return torch.tensor(coords, dtype=dtype, device=device)

    box1 = make([-1, -1, 1, 1])
    box2 = make([0, 0, 1, 1])
    box3 = make([0, 1, 1, 2])
    box4 = make([1, 1, 2, 2])

    # Batched fixtures: box2 twice, paired element-wise with (box3, box4).
    box1s = torch.stack((box2, box2), dim=0)
    box2s = torch.stack((box3, box4), dim=0)

    return box1, box2, box3, box4, box1s, box2s
1555

Aditya Oke's avatar
Aditya Oke committed
1556

Aditya Oke's avatar
Aditya Oke committed
1557
1558
1559
1560
def assert_iou_loss(iou_fn, box1, box2, expected_loss, device, reduction="none"):
    """Assert that ``iou_fn(box1, box2, reduction=reduction)`` is close to ``expected_loss``.

    ``expected_loss`` is a plain Python value; it is turned into a tensor on ``device``
    and compared with the default ``torch.testing.assert_close`` tolerances.
    """
    actual = iou_fn(box1, box2, reduction=reduction)
    reference = torch.tensor(expected_loss, device=device)
    torch.testing.assert_close(actual, reference)
1561
1562


Aditya Oke's avatar
Aditya Oke committed
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
def assert_empty_loss(iou_fn, dtype, device):
    """Check how ``iou_fn`` behaves on two empty ``(0, 4)`` box tensors.

    With ``reduction="mean"`` the loss must be 0 and backward must populate the
    gradients of both inputs; with ``reduction="none"`` the loss must be empty.
    """
    empty_shape = [0, 4]
    box1 = torch.randn(empty_shape, dtype=dtype, device=device).requires_grad_()
    box2 = torch.randn(empty_shape, dtype=dtype, device=device).requires_grad_()

    mean_loss = iou_fn(box1, box2, reduction="mean")
    mean_loss.backward()
    torch.testing.assert_close(mean_loss, torch.tensor(0.0, device=device))
    assert box1.grad is not None, "box1.grad should not be None after backward is called"
    assert box2.grad is not None, "box2.grad should not be None after backward is called"

    unreduced_loss = iou_fn(box1, box2, reduction="none")
    assert unreduced_loss.numel() == 0, f"{str(iou_fn)} for two empty box should be empty"
Aditya Oke's avatar
Aditya Oke committed
1573

Aditya Oke's avatar
Aditya Oke committed
1574

Aditya Oke's avatar
Aditya Oke committed
1575
1576
class TestGeneralizedBoxIouLoss:
    """Tests for ``ops.generalized_box_iou_loss``."""

    # We refer to original test: https://github.com/facebookresearch/fvcore/blob/main/tests/test_giou_loss.py
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_giou_loss(self, dtype, device):
        """Check the loss on hand-computed box pairs, the reductions, and the reduction validation."""
        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)
        giou_loss = ops.generalized_box_iou_loss

        unreduced_cases = [
            # Identical boxes should have loss of 0
            (box1, box1, 0.0),
            # quarter size box inside other box = IoU of 0.25
            (box1, box2, 0.75),
            # Two side by side boxes, area=union
            # IoU=0 and GIoU=0 (loss 1.0)
            (box2, box3, 1.0),
            # Two diagonally adjacent boxes, area=2*union
            # IoU=0 and GIoU=-0.5 (loss 1.5)
            (box2, box4, 1.5),
        ]
        for lhs, rhs, expected_loss in unreduced_cases:
            assert_iou_loss(giou_loss, lhs, rhs, expected_loss, device=device)

        # Test batched loss and reductions
        for reduction, expected_loss in (("sum", 2.5), ("mean", 1.25)):
            assert_iou_loss(giou_loss, box1s, box2s, expected_loss, device=device, reduction=reduction)

        # Test reduction value
        # reduction value other than ["none", "mean", "sum"] should raise a ValueError
        with pytest.raises(ValueError, match="Invalid"):
            giou_loss(box1s, box2s, reduction="xyz")

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_empty_inputs(self, dtype, device):
        """Empty box tensors must give a zero mean loss and an empty unreduced loss."""
        assert_empty_loss(ops.generalized_box_iou_loss, dtype, device)
Yassine Alouini's avatar
Yassine Alouini committed
1609
1610


Aditya Oke's avatar
Aditya Oke committed
1611
1612
class TestCompleteBoxIouLoss:
    """Tests for ``ops.complete_box_iou_loss``."""

    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_ciou_loss(self, dtype, device):
        """Check the loss on fixed box pairs, the reductions, and the reduction validation."""
        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)
        ciou_loss = ops.complete_box_iou_loss

        # (first boxes, second boxes, expected loss, reduction), checked in this order.
        cases = [
            (box1, box1, 0.0, "none"),
            (box1, box2, 0.8125, "none"),
            (box1, box3, 1.1923, "none"),
            (box1, box4, 1.2500, "none"),
            (box1s, box2s, 1.2250, "mean"),
            (box1s, box2s, 2.4500, "sum"),
        ]
        for lhs, rhs, expected_loss, reduction in cases:
            assert_iou_loss(ciou_loss, lhs, rhs, expected_loss, device=device, reduction=reduction)

        # An unsupported reduction name must be rejected.
        with pytest.raises(ValueError, match="Invalid"):
            ciou_loss(box1s, box2s, reduction="xyz")

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_empty_inputs(self, dtype, device):
        """Empty box tensors must give a zero mean loss and an empty unreduced loss."""
        assert_empty_loss(ops.complete_box_iou_loss, dtype, device)
Yassine Alouini's avatar
Yassine Alouini committed
1631
1632


Aditya Oke's avatar
Aditya Oke committed
1633
class TestDistanceBoxIouLoss:
    """Tests for ``ops.distance_box_iou_loss``."""

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_distance_iou_loss(self, dtype, device):
        """Check the loss on fixed box pairs, the reductions, and the reduction validation."""
        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)
        diou_loss = ops.distance_box_iou_loss

        # (first boxes, second boxes, expected loss, reduction), checked in this order.
        cases = [
            (box1, box1, 0.0, "none"),
            (box1, box2, 0.8125, "none"),
            (box1, box3, 1.1923, "none"),
            (box1, box4, 1.2500, "none"),
            (box1s, box2s, 1.2250, "mean"),
            (box1s, box2s, 2.4500, "sum"),
        ]
        for lhs, rhs, expected_loss, reduction in cases:
            assert_iou_loss(diou_loss, lhs, rhs, expected_loss, device=device, reduction=reduction)

        # An unsupported reduction name must be rejected.
        with pytest.raises(ValueError, match="Invalid"):
            diou_loss(box1s, box2s, reduction="xyz")

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_empty_distance_iou_inputs(self, dtype, device):
        """Empty box tensors must give a zero mean loss and an empty unreduced loss."""
        assert_empty_loss(ops.distance_box_iou_loss, dtype, device)
Yassine Alouini's avatar
Yassine Alouini committed
1653
1654


Aditya Oke's avatar
Aditya Oke committed
1655
1656
1657
1658
class TestFocalLoss:
    """Tests for ``ops.sigmoid_focal_loss``."""

    def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs):
        """Generate logits and targets covering several input/target value ranges.

        Nine (input range, target range) combinations are generated and concatenated.
        ``kwargs`` are forwarded to the tensor factories (the tests in this class pass
        ``dtype`` and ``device``). The result depends on the global torch RNG state,
        which callers seed beforehand.

        Returns:
            ``(inputs, targets)``, each the concatenation of the nine generated tensors.
        """

        def logit(p):
            # Inverse of the sigmoid: maps a probability to a logit.
            return torch.log(p / (1 - p))

        def generate_tensor_with_range_type(shape, range_type, **kwargs):
            # "random_binary" draws integers from {0, 1}; every other range type
            # draws values between the bounds listed for it below.
            if range_type != "random_binary":
                low, high = {
                    "small": (0.0, 0.2),
                    "big": (0.8, 1.0),
                    "zeros": (0.0, 0.0),
                    "ones": (1.0, 1.0),
                    "random": (0.0, 1.0),
                }[range_type]
                return torch.testing.make_tensor(shape, low=low, high=high, **kwargs)
            else:
                return torch.randint(0, 2, shape, **kwargs)

        # This function will return inputs and targets with shape: (shape[0]*9, shape[1])
        inputs = []
        targets = []
        for input_range_type, target_range_type in [
            ("small", "zeros"),
            ("small", "ones"),
            ("small", "random_binary"),
            ("big", "zeros"),
            ("big", "ones"),
            ("big", "random_binary"),
            ("random", "zeros"),
            ("random", "ones"),
            ("random", "random_binary"),
        ]:
            # Inputs are generated as probabilities and converted to logits.
            inputs.append(logit(generate_tensor_with_range_type(shape, input_range_type, **kwargs)))
            targets.append(generate_tensor_with_range_type(shape, target_range_type, **kwargs))

        return torch.cat(inputs), torch.cat(targets)

    @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
    @pytest.mark.parametrize("gamma", [0, 2])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("seed", [0, 1])
    def test_correct_ratio(self, alpha, gamma, device, dtype, seed):
        """Check the element-wise ratio between the focal loss and the cross-entropy loss.

        The ratio must equal ``(1 - p_t) ** gamma``, additionally scaled by ``alpha_t``
        when ``alpha >= 0``, and the focal loss must never exceed the cross-entropy loss.
        """
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        # For testing the ratio with manual calculation, we require the reduction to be "none"
        reduction = "none"
        torch.random.manual_seed(seed)
        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
        focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
        ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction)

        assert torch.all(
            focal_loss <= ce_loss
        ), "focal loss must be less or equal to cross entropy loss with same input"

        loss_ratio = (focal_loss / ce_loss).squeeze()
        prob = torch.sigmoid(inputs)
        # p_t is the probability assigned to the target class of each element.
        p_t = prob * targets + (1 - prob) * (1 - targets)
        correct_ratio = (1.0 - p_t) ** gamma
        # The alpha weighting is only applied for non-negative alpha.
        if alpha >= 0:
            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
            correct_ratio = correct_ratio * alpha_t

        # Half precision needs a looser tolerance.
        tol = 1e-3 if dtype is torch.half else 1e-5
        torch.testing.assert_close(correct_ratio, loss_ratio, atol=tol, rtol=tol)

    @pytest.mark.parametrize("reduction", ["mean", "sum"])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("seed", [2, 3])
    def test_equal_ce_loss(self, reduction, device, dtype, seed):
        """With ``alpha=-1`` and ``gamma=0`` the focal loss and its gradient must match cross entropy."""
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        # focal loss should be equal ce_loss if alpha=-1 and gamma=0
        alpha = -1
        gamma = 0
        torch.random.manual_seed(seed)
        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
        # Separate clones so that each loss accumulates gradients into its own leaf tensor.
        inputs_fl = inputs.clone().requires_grad_()
        targets_fl = targets.clone()
        inputs_ce = inputs.clone().requires_grad_()
        targets_ce = targets.clone()
        focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction)
        ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction)

        torch.testing.assert_close(focal_loss, ce_loss)

        focal_loss.backward()
        ce_loss.backward()
        torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad)

    @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
    @pytest.mark.parametrize("gamma", [0, 2])
    @pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("seed", [4, 5])
    def test_jit(self, alpha, gamma, reduction, device, dtype, seed):
        """The TorchScript-compiled focal loss must match the eager focal loss."""
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        script_fn = torch.jit.script(ops.sigmoid_focal_loss)
        torch.random.manual_seed(seed)
        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
        focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
        scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)

        # Half precision needs a looser tolerance.
        tol = 1e-3 if dtype is torch.half else 1e-5
        torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol)

    # An unrecognized reduction mode must raise a ValueError
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_reduction_mode(self, device, dtype, reduction="xyz"):
        """Passing the unsupported reduction name ``"xyz"`` must raise ``ValueError``."""
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        torch.random.manual_seed(0)
        inputs, targets = self._generate_diverse_input_target_pair(device=device, dtype=dtype)
        with pytest.raises(ValueError, match="Invalid"):
            ops.sigmoid_focal_loss(inputs, targets, 0.25, 2, reduction)

Abhijit Deo's avatar
Abhijit Deo committed
1776

1777
1778
class TestMasksToBoxes:
    """Tests for ``ops.masks_to_boxes``."""

    def test_masks_box(self):
        """Check the boxes computed from the multi-frame ``assets/masks.tiff`` image.

        Every frame of the image is one mask. The masks are loaded as float16,
        float32 and float64 tensors, and the resulting boxes must always be
        ``torch.float`` and match the expected coordinates.
        """

        def masks_box_check(masks, expected, atol=1e-4):
            out = ops.masks_to_boxes(masks)
            assert out.dtype == torch.float
            torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=True, atol=atol)

        def _create_masks(image, masks):
            # Copy every frame of the image into the matching slot of ``masks``.
            for index in range(image.n_frames):
                image.seek(index)
                frame = np.array(image)
                masks[index] = torch.tensor(frame)

            return masks

        # One expected box per frame of the mask image.
        expected = torch.tensor(
            [
                [127, 2, 165, 40],
                [2, 50, 44, 92],
                [56, 63, 98, 100],
                [139, 68, 175, 104],
                [160, 112, 198, 145],
                [49, 138, 99, 182],
                [108, 148, 152, 213],
            ],
            dtype=torch.float,
        )

        assets_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
        mask_path = os.path.join(assets_directory, "masks.tiff")
        # Frames are read on demand via seek(), so the file has to stay open while the
        # masks are built; the context manager closes it afterwards instead of leaking it.
        with Image.open(mask_path) as image:
            for dtype in [torch.float16, torch.float32, torch.float64]:
                masks = torch.zeros((image.n_frames, image.height, image.width), dtype=dtype)
                masks = _create_masks(image, masks)
                masks_box_check(masks, expected)


1819
class TestStochasticDepth:
    """Tests for ``ops.StochasticDepth``."""

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("p", [0.2, 0.5, 0.8])
    @pytest.mark.parametrize("mode", ["batch", "row"])
    def test_stochastic_depth_random(self, seed, mode, p):
        """Check statistically that the observed drop frequency is consistent with ``p``.

        The layer is applied 250 times to an all-ones input, and the number of zeroed
        outputs is tested against ``p`` with a binomial test. Skipped when scipy is
        not installed.
        """
        torch.manual_seed(seed)
        stats = pytest.importorskip("scipy.stats")
        batch_size = 5
        x = torch.ones(size=(batch_size, 3, 4, 4))
        layer = ops.StochasticDepth(p=p, mode=mode)
        # Smoke check: building the repr must not raise.
        layer.__repr__()

        trials = 250
        num_samples = 0
        counts = 0
        for _ in range(trials):
            out = layer(x)
            # Number of batch rows that still contain non-zero values.
            non_zero_count = out.sum(dim=(1, 2, 3)).nonzero().size(0)
            if mode == "batch":
                # One sample per trial: counted as a drop only if every row is zero.
                if non_zero_count == 0:
                    counts += 1
                num_samples += 1
            elif mode == "row":
                # One sample per row: each zeroed row counts as a drop.
                counts += batch_size - non_zero_count
                num_samples += batch_size

        p_value = stats.binomtest(counts, num_samples, p=p).pvalue
        assert p_value > 0.01

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("p", (0, 1))
    @pytest.mark.parametrize("mode", ["batch", "row"])
    def test_stochastic_depth(self, seed, mode, p):
        """Check the extreme probabilities: ``p=0`` keeps the input, ``p=1`` zeroes it."""
        torch.manual_seed(seed)
        batch_size = 5
        x = torch.ones(size=(batch_size, 3, 4, 4))
        layer = ops.StochasticDepth(p=p, mode=mode)

        out = layer(x)
        if p == 0:
            assert out.equal(x)
        elif p == 1:
            assert out.equal(torch.zeros_like(x))

    def make_obj(self, p, mode, wrap=False):
        """Create a ``StochasticDepth`` layer, wrapped in ``StochasticDepthWrapper`` if ``wrap`` is set."""
        obj = ops.StochasticDepth(p, mode)
        return StochasticDepthWrapper(obj) if wrap else obj

    @pytest.mark.parametrize("p", (0, 1))
    @pytest.mark.parametrize("mode", ["batch", "row"])
    def test_is_leaf_node(self, p, mode):
        """Check that the wrapped layer contributes one graph node on top of its inputs."""
        op_obj = self.make_obj(p, mode, wrap=True)
        graph_node_names = get_graph_node_names(op_obj)

        # Two equally long name lists are expected (presumably train and eval mode),
        # each holding the inputs plus a single node for the op itself.
        assert len(graph_node_names) == 2
        assert len(graph_node_names[0]) == len(graph_node_names[1])
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs

1877

1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
class TestUtils:
    """Tests for the helpers in ``ops._utils``."""

    @pytest.mark.parametrize("norm_layer", [None, nn.BatchNorm2d, nn.LayerNorm])
    def test_split_normalization_params(self, norm_layer):
        """Check the sizes of the two parameter groups returned for ``mobilenet_v3_large``."""
        # ``None`` is passed through as-is; otherwise the class is wrapped in a list.
        norm_classes = [norm_layer] if norm_layer is not None else None
        model = models.mobilenet_v3_large(norm_layer=norm_layer)
        param_groups = ops._utils.split_normalization_params(model, norm_classes)

        first_group_size = len(param_groups[0])
        second_group_size = len(param_groups[1])
        assert first_group_size == 92
        assert second_group_size == 82


1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
class TestDropBlock:
    """Tests for ``ops.DropBlock2d`` and ``ops.DropBlock3d``."""

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("dim", [2, 3])
    @pytest.mark.parametrize("p", [0, 0.5])
    @pytest.mark.parametrize("block_size", [5, 11])
    @pytest.mark.parametrize("inplace", [True, False])
    def test_drop_block(self, seed, dim, p, block_size, inplace):
        """Check the deterministic properties of DropBlock on an all-ones input.

        With ``p=0`` the output must equal the input; when the block size equals the
        feature size, every (batch, channel) slice must be entirely kept or entirely dropped.
        """
        torch.manual_seed(seed)
        batch_size = 5
        channels = 3
        height = 11
        width = height
        depth = height
        if dim == 2:
            x = torch.ones(size=(batch_size, channels, height, width))
            layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace)
            feature_size = height * width
        elif dim == 3:
            x = torch.ones(size=(batch_size, channels, depth, height, width))
            layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace)
            feature_size = depth * height * width
        # Smoke check: building the repr must not raise.
        layer.__repr__()

        out = layer(x)
        if p == 0:
            assert out.equal(x)
        if block_size == height:
            # A block as large as the feature map leaves no room for partial drops.
            for b, c in product(range(batch_size), range(channels)):
                assert out[b, c].count_nonzero() in (0, feature_size)

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("dim", [2, 3])
    @pytest.mark.parametrize("p", [0.1, 0.2])
    @pytest.mark.parametrize("block_size", [3])
    @pytest.mark.parametrize("inplace", [False])
    def test_drop_block_random(self, seed, dim, p, block_size, inplace):
        """Check that the fraction of zeroed elements over 250 trials is within 15% of ``p``."""
        torch.manual_seed(seed)
        batch_size = 5
        channels = 3
        height = 11
        width = height
        depth = height
        if dim == 2:
            x = torch.ones(size=(batch_size, channels, height, width))
            layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace)
        elif dim == 3:
            x = torch.ones(size=(batch_size, channels, depth, height, width))
            layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace)

        trials = 250
        num_samples = 0
        counts = 0
        # Total number of elements in the input.
        cell_numel = torch.tensor(x.shape).prod()
        for _ in range(trials):
            with torch.no_grad():
                out = layer(x)
            non_zero_count = out.nonzero().size(0)
            counts += cell_numel - non_zero_count
            num_samples += cell_numel

        # Relative error between the requested and the observed drop rate.
        assert abs(p - counts / num_samples) / p < 0.15

    def make_obj(self, dim, p, block_size, inplace, wrap=False):
        """Create a 2d or 3d DropBlock layer, wrapped in ``DropBlockWrapper`` if ``wrap`` is set."""
        if dim == 2:
            obj = ops.DropBlock2d(p, block_size, inplace)
        elif dim == 3:
            obj = ops.DropBlock3d(p, block_size, inplace)
        return DropBlockWrapper(obj) if wrap else obj

    @pytest.mark.parametrize("dim", (2, 3))
    @pytest.mark.parametrize("p", [0, 1])
    @pytest.mark.parametrize("block_size", [5, 7])
    @pytest.mark.parametrize("inplace", [True, False])
    def test_is_leaf_node(self, dim, p, block_size, inplace):
        """Check that the wrapped layer contributes one graph node on top of its inputs."""
        op_obj = self.make_obj(dim, p, block_size, inplace, wrap=True)
        graph_node_names = get_graph_node_names(op_obj)

        # Two equally long name lists are expected (presumably train and eval mode),
        # each holding the inputs plus a single node for the op itself.
        assert len(graph_node_names) == 2
        assert len(graph_node_names[0]) == len(graph_node_names[1])
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs


1970
if __name__ == "__main__":
    # Running this file as a script hands this file over to pytest.
    pytest.main(args=[__file__])