from common_utils import needs_cuda, cpu_only
from _assert_utils import assert_equal
import math
import unittest
import pytest

import numpy as np

import torch
from functools import lru_cache
from torch import Tensor
from torch.autograd import gradcheck
from torch.nn.modules.utils import _pair
from torchvision import ops
from typing import Tuple


class OpTester(object):
    @classmethod
    def setUpClass(cls):
        cls.dtype = torch.float64

    def test_forward_cpu_contiguous(self):
        self._test_forward(device=torch.device('cpu'), contiguous=True)

    def test_forward_cpu_non_contiguous(self):
        self._test_forward(device=torch.device('cpu'), contiguous=False)

    def test_backward_cpu_contiguous(self):
        self._test_backward(device=torch.device('cpu'), contiguous=True)

    def test_backward_cpu_non_contiguous(self):
        self._test_backward(device=torch.device('cpu'), contiguous=False)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
    def test_forward_cuda_contiguous(self):
        self._test_forward(device=torch.device('cuda'), contiguous=True)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
    def test_forward_cuda_non_contiguous(self):
        self._test_forward(device=torch.device('cuda'), contiguous=False)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
    def test_backward_cuda_contiguous(self):
        self._test_backward(device=torch.device('cuda'), contiguous=True)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
    def test_backward_cuda_non_contiguous(self):
        self._test_backward(device=torch.device('cuda'), contiguous=False)

    def _test_forward(self, device, contiguous):
        pass

    def _test_backward(self, device, contiguous):
        pass


class RoIOpTester(OpTester):
    def _test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwargs):
        x_dtype = self.dtype if x_dtype is None else x_dtype
        rois_dtype = self.dtype if rois_dtype is None else rois_dtype
        pool_size = 5
        # n_channels % (pool_size ** 2) == 0 is required for the position-sensitive (PS) operations.
        n_channels = 2 * (pool_size ** 2)
        x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device)
        if not contiguous:
            x = x.permute(0, 1, 3, 2)
        rois = torch.tensor([[0, 0, 0, 9, 9],  # format is (batch_idx, x1, y1, x2, y2)
                             [0, 0, 5, 4, 9],
                             [0, 5, 5, 9, 9],
                             [1, 0, 0, 9, 9]],
                            dtype=rois_dtype, device=device)

        pool_h, pool_w = pool_size, pool_size
        y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs)
        # the following should be true whether we're running an autocast test or not.
        self.assertTrue(y.dtype == x.dtype)
        gt_y = self.expected_fn(x, rois, pool_h, pool_w, spatial_scale=1,
                                sampling_ratio=-1, device=device, dtype=self.dtype, **kwargs)

        tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
        torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)

    def _test_backward(self, device, contiguous):
        pool_size = 2
        x = torch.rand(1, 2 * (pool_size ** 2), 5, 5, dtype=self.dtype, device=device, requires_grad=True)
        if not contiguous:
            x = x.permute(0, 1, 3, 2)
        rois = torch.tensor([[0, 0, 0, 4, 4],  # format is (batch_idx, x1, y1, x2, y2)
                             [0, 0, 2, 3, 4],
                             [0, 2, 2, 4, 4]],
                            dtype=self.dtype, device=device)

        def func(z):
            return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)

        script_func = self.get_script_fn(rois, pool_size)

        self.assertTrue(gradcheck(func, (x,)))
        self.assertTrue(gradcheck(script_func, (x,)))

    def test_boxes_shape(self):
        self._test_boxes_shape()

    def _helper_boxes_shape(self, func):
        # test boxes as Tensor[N, 5]
        with self.assertRaises(AssertionError):
            a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
            boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype)
            func(a, boxes, output_size=(2, 2))

        # test boxes as List[Tensor[N, 4]]
        with self.assertRaises(AssertionError):
            a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
            boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype)
            func(a, [boxes], output_size=(2, 2))

    def fn(*args, **kwargs):
        pass

    def get_script_fn(*args, **kwargs):
        pass

    def expected_fn(*args, **kwargs):
        pass

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
    def test_autocast(self):
        for x_dtype in (torch.float, torch.half):
            for rois_dtype in (torch.float, torch.half):
                with torch.cuda.amp.autocast():
                    self._test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)


class RoIPoolTester(RoIOpTester, unittest.TestCase):
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        return ops.RoIPool((pool_h, pool_w), spatial_scale)(x, rois)

    def get_script_fn(self, rois, pool_size):
        scripted = torch.jit.script(ops.roi_pool)
        return lambda x: scripted(x, rois, pool_size)

    def expected_fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1,
                    device=None, dtype=torch.float64):
        if device is None:
            device = torch.device("cpu")

        n_channels = x.size(1)
        y = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device)

        def get_slice(k, block):
            return slice(int(np.floor(k * block)), int(np.ceil((k + 1) * block)))

        for roi_idx, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (int(round(x.item() * spatial_scale)) for x in roi[1:])
            roi_x = x[batch_idx, :, i_begin:i_end + 1, j_begin:j_end + 1]

            roi_h, roi_w = roi_x.shape[-2:]
            bin_h = roi_h / pool_h
            bin_w = roi_w / pool_w

            for i in range(0, pool_h):
                for j in range(0, pool_w):
                    bin_x = roi_x[:, get_slice(i, bin_h), get_slice(j, bin_w)]
                    if bin_x.numel() > 0:
                        y[roi_idx, :, i, j] = bin_x.reshape(n_channels, -1).max(dim=1)[0]
        return y

    def _test_boxes_shape(self):
        self._helper_boxes_shape(ops.roi_pool)


class PSRoIPoolTester(RoIOpTester, unittest.TestCase):
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        return ops.PSRoIPool((pool_h, pool_w), 1)(x, rois)

    def get_script_fn(self, rois, pool_size):
        scripted = torch.jit.script(ops.ps_roi_pool)
        return lambda x: scripted(x, rois, pool_size)

    def expected_fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1,
                    device=None, dtype=torch.float64):
        if device is None:
            device = torch.device("cpu")
        n_input_channels = x.size(1)
        self.assertEqual(n_input_channels % (pool_h * pool_w), 0, "input channels must be divisible by ph * pw")
        n_output_channels = int(n_input_channels / (pool_h * pool_w))
        y = torch.zeros(rois.size(0), n_output_channels, pool_h, pool_w, dtype=dtype, device=device)

        def get_slice(k, block):
            return slice(int(np.floor(k * block)), int(np.ceil((k + 1) * block)))

        for roi_idx, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (int(round(x.item() * spatial_scale)) for x in roi[1:])
            roi_x = x[batch_idx, :, i_begin:i_end + 1, j_begin:j_end + 1]

            roi_height = max(i_end - i_begin, 1)
            roi_width = max(j_end - j_begin, 1)
            bin_h, bin_w = roi_height / float(pool_h), roi_width / float(pool_w)

            for i in range(0, pool_h):
                for j in range(0, pool_w):
                    bin_x = roi_x[:, get_slice(i, bin_h), get_slice(j, bin_w)]
                    if bin_x.numel() > 0:
                        area = bin_x.size(-2) * bin_x.size(-1)
                        for c_out in range(0, n_output_channels):
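                            # position-sensitive mapping: output channel c_out at
                            # bin (i, j) pools from input channel
                            # c_out * (pool_h * pool_w) + pool_w * i + j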
                            c_in = c_out * (pool_h * pool_w) + pool_w * i + j
                            t = torch.sum(bin_x[c_in, :, :])
                            y[roi_idx, c_out, i, j] = t / area
        return y

    def _test_boxes_shape(self):
        self._helper_boxes_shape(ops.ps_roi_pool)


def bilinear_interpolate(data, y, x, snap_border=False):
    height, width = data.shape

    if snap_border:
        if -1 < y <= 0:
            y = 0
        elif height - 1 <= y < height:
            y = height - 1

        if -1 < x <= 0:
            x = 0
        elif width - 1 <= x < width:
            x = width - 1

    y_low = int(math.floor(y))
    x_low = int(math.floor(x))
    y_high = y_low + 1
    x_high = x_low + 1

    wy_h = y - y_low
    wx_h = x - x_low
    wy_l = 1 - wy_h
    wx_l = 1 - wx_h

    val = 0
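    # accumulate the four corner contributions weighted by the bilinear
    # coefficients; out-of-bounds corners contribute zero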
    for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):
        for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):
            if 0 <= yp < height and 0 <= xp < width:
                val += wx * wy * data[yp, xp]
    return val


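# A minimal sanity check for the bilinear_interpolate reference helper above:
# at integer coordinates the interpolation returns the pixel value itself, and
# at the midpoint of four pixels, their average.
class BilinearInterpolateTester(unittest.TestCase):
    def test_identity_and_midpoint(self):
        data = torch.arange(4, dtype=torch.float64).reshape(2, 2)
        self.assertEqual(bilinear_interpolate(data, 1, 1).item(), data[1, 1].item())
        self.assertAlmostEqual(bilinear_interpolate(data, 0.5, 0.5).item(), data.mean().item())

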
class RoIAlignTester(RoIOpTester, unittest.TestCase):
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs):
        return ops.RoIAlign((pool_h, pool_w), spatial_scale=spatial_scale,
                            sampling_ratio=sampling_ratio, aligned=aligned)(x, rois)

    def get_script_fn(self, rois, pool_size):
        scripted = torch.jit.script(ops.roi_align)
        return lambda x: scripted(x, rois, pool_size)

    def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False,
                    device=None, dtype=torch.float64):
        if device is None:
            device = torch.device("cpu")
        n_channels = in_data.size(1)
        out_data = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device)

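        # aligned=True shifts the ROI coordinates by -0.5 below, so that the
        # sampling grid is aligned with pixel centers rather than pixel corners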
        offset = 0.5 if aligned else 0.

        for r, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (x.item() * spatial_scale - offset for x in roi[1:])

            roi_h = i_end - i_begin
            roi_w = j_end - j_begin
            bin_h = roi_h / pool_h
            bin_w = roi_w / pool_w

            for i in range(0, pool_h):
                start_h = i_begin + i * bin_h
                grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h))
                for j in range(0, pool_w):
                    start_w = j_begin + j * bin_w
                    grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))

                    for channel in range(0, n_channels):

                        val = 0
                        for iy in range(0, grid_h):
                            y = start_h + (iy + 0.5) * bin_h / grid_h
                            for ix in range(0, grid_w):
                                x = start_w + (ix + 0.5) * bin_w / grid_w
                                val += bilinear_interpolate(in_data[batch_idx, channel, :, :], y, x, snap_border=True)
                        val /= grid_h * grid_w

                        out_data[r, channel, i, j] = val
        return out_data

    def _test_boxes_shape(self):
        self._helper_boxes_shape(ops.roi_align)

    def _test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwargs):
        for aligned in (True, False):
            super()._test_forward(device, contiguous, x_dtype, rois_dtype, aligned=aligned)

    def test_qroialign(self):
        """Make sure quantized version of RoIAlign is close to float version"""
        pool_size = 5
        img_size = 10
        n_channels = 2
        num_imgs = 1
        dtype = torch.float

        def make_rois(num_rois=1000):
            rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype)
            rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,))  # set batch index
            rois[:, 3:] += rois[:, 1:3]  # make sure boxes aren't degenerate
            return rois

        for aligned in (True, False):
            for scale, zero_point in ((1, 0), (2, 10), (0.1, 50)):
                for qdtype in (torch.qint8, torch.quint8, torch.qint32):

                    x = torch.randint(50, 100, size=(num_imgs, n_channels, img_size, img_size)).to(dtype)
                    qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=qdtype)

                    rois = make_rois()
                    qrois = torch.quantize_per_tensor(rois, scale=scale, zero_point=zero_point, dtype=qdtype)

                    x, rois = qx.dequantize(), qrois.dequantize()  # we want to pass the same inputs

                    y = ops.roi_align(
                        x,
                        rois,
                        output_size=pool_size,
                        spatial_scale=1,
                        sampling_ratio=-1,
                        aligned=aligned,
                    )
                    qy = ops.roi_align(
                        qx,
                        qrois,
                        output_size=pool_size,
                        spatial_scale=1,
                        sampling_ratio=-1,
                        aligned=aligned,
                    )

                    # The output qy is itself a quantized tensor and there might have been a loss of info when it was
                    # quantized. For a fair comparison we need to quantize y as well
                    quantized_float_y = torch.quantize_per_tensor(y, scale=scale, zero_point=zero_point, dtype=qdtype)

                    try:
                        # Ideally, we would assert this, which passes with (scale, zero) == (1, 0)
                        self.assertTrue((qy == quantized_float_y).all())
                    except AssertionError:
                        # But because the computations aren't exactly the same between the two RoIAlign procedures, some
                        # rounding error may lead to a difference of 2 in the output.
                        # For example with (scale, zero) = (2, 10), 45.00000... will be quantized to 44
                        # but 45.00000001 will be rounded to 46. We make sure below that:
                        # - such discrepancies between qy and quantized_float_y are very rare (less than 5%)
                        # - any difference between qy and quantized_float_y is == scale
                        diff_idx = torch.where(qy != quantized_float_y)
                        num_diff = diff_idx[0].numel()
                        self.assertTrue(num_diff / qy.numel() < .05)

                        abs_diff = torch.abs(qy[diff_idx].dequantize() - quantized_float_y[diff_idx].dequantize())
                        t_scale = torch.full_like(abs_diff, fill_value=scale)
                        torch.testing.assert_close(abs_diff, t_scale, rtol=1e-5, atol=1e-5)

        x = torch.randint(50, 100, size=(2, 3, 10, 10)).to(dtype)
        qx = torch.quantize_per_tensor(x, scale=1, zero_point=0, dtype=torch.qint8)
        rois = make_rois(10)
        qrois = torch.quantize_per_tensor(rois, scale=1, zero_point=0, dtype=torch.qint8)
        with self.assertRaisesRegex(RuntimeError, "Only one image per batch is allowed"):
            ops.roi_align(qx, qrois, output_size=pool_size)


class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale,
                              sampling_ratio=sampling_ratio)(x, rois)

    def get_script_fn(self, rois, pool_size):
        scripted = torch.jit.script(ops.ps_roi_align)
        return lambda x: scripted(x, rois, pool_size)

    def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1,
                    sampling_ratio=-1, device=None, dtype=torch.float64):
        if device is None:
            device = torch.device("cpu")
        n_input_channels = in_data.size(1)
        self.assertEqual(n_input_channels % (pool_h * pool_w), 0, "input channels must be divisible by ph * pw")
        n_output_channels = int(n_input_channels / (pool_h * pool_w))
        out_data = torch.zeros(rois.size(0), n_output_channels, pool_h, pool_w, dtype=dtype, device=device)

        for r, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (x.item() * spatial_scale - 0.5 for x in roi[1:])

            roi_h = i_end - i_begin
            roi_w = j_end - j_begin
            bin_h = roi_h / pool_h
            bin_w = roi_w / pool_w

            for i in range(0, pool_h):
                start_h = i_begin + i * bin_h
                grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h))
                for j in range(0, pool_w):
                    start_w = j_begin + j * bin_w
                    grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))
                    for c_out in range(0, n_output_channels):
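                        # same position-sensitive channel mapping as in the
                        # PSRoIPool reference above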
                        c_in = c_out * (pool_h * pool_w) + pool_w * i + j

                        val = 0
                        for iy in range(0, grid_h):
                            y = start_h + (iy + 0.5) * bin_h / grid_h
                            for ix in range(0, grid_w):
                                x = start_w + (ix + 0.5) * bin_w / grid_w
                                val += bilinear_interpolate(in_data[batch_idx, c_in, :, :], y, x, snap_border=True)
                        val /= grid_h * grid_w

                        out_data[r, c_out, i, j] = val
        return out_data

    def _test_boxes_shape(self):
        self._helper_boxes_shape(ops.ps_roi_align)


class MultiScaleRoIAlignTester(unittest.TestCase):
    def test_msroialign_repr(self):
        fmap_names = ['0']
        output_size = (7, 7)
        sampling_ratio = 2
        # Pass mock feature map names
        t = ops.poolers.MultiScaleRoIAlign(fmap_names, output_size, sampling_ratio)

        # Check integrity of object __repr__ attribute
        expected_string = (f"MultiScaleRoIAlign(featmap_names={fmap_names}, output_size={output_size}, "
                           f"sampling_ratio={sampling_ratio})")
        self.assertEqual(t.__repr__(), expected_string)


class TestNMS:
    def _reference_nms(self, boxes, scores, iou_threshold):
        """
        Args:
            boxes (Tensor[N, 4]): boxes in corner (x1, y1, x2, y2) format.
            scores (Tensor[N]): scores for each of the boxes.
            iou_threshold: intersection over union threshold.
        Returns:
             picked: a list of indexes of the kept boxes
        """
        picked = []
        _, indexes = scores.sort(descending=True)
        while len(indexes) > 0:
            current = indexes[0]
            picked.append(current.item())
            if len(indexes) == 1:
                break
            current_box = boxes[current, :]
            indexes = indexes[1:]
            rest_boxes = boxes[indexes, :]
            iou = ops.box_iou(rest_boxes, current_box.unsqueeze(0)).squeeze(1)
            indexes = indexes[iou <= iou_threshold]

        return torch.as_tensor(picked)

    def _create_tensors_with_iou(self, N, iou_thresh):
        # force last box to have a pre-defined iou with the first box
        # let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1],
        # then, in order to satisfy ops.iou(b0, b1) == iou_thresh,
        # we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh
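        # (since b1 fully contains b0: intersection = w * h and union = (w + d) * h
        # with w = x1 - x0, so iou = w / (w + d), which gives the d above)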
        # Adjust the threshold upward a bit with the intent of creating
        # at least one box that exceeds (barely) the threshold and so
        # should be suppressed.
        boxes = torch.rand(N, 4) * 100
        boxes[:, 2:] += boxes[:, :2]
        boxes[-1, :] = boxes[0, :]
        x0, y0, x1, y1 = boxes[-1].tolist()
        iou_thresh += 1e-5
        boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh
        scores = torch.rand(N)
        return boxes, scores

    @cpu_only
    @pytest.mark.parametrize("iou", (.2, .5, .8))
    def test_nms_ref(self, iou):
        err_msg = 'NMS incompatible between CPU and reference implementation for IoU={}'
        boxes, scores = self._create_tensors_with_iou(1000, iou)
        keep_ref = self._reference_nms(boxes, scores, iou)
        keep = ops.nms(boxes, scores, iou)
        assert torch.allclose(keep, keep_ref), err_msg.format(iou)

    @cpu_only
    def test_nms_input_errors(self):
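        # boxes must be a 2-dim tensor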
        with pytest.raises(RuntimeError):
            ops.nms(torch.rand(4), torch.rand(3), 0.5)
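        # boxes must have 4 columns (x1, y1, x2, y2)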
        with pytest.raises(RuntimeError):
            ops.nms(torch.rand(3, 5), torch.rand(3), 0.5)
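        # scores must be a 1-dim tensor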
        with pytest.raises(RuntimeError):
            ops.nms(torch.rand(3, 4), torch.rand(3, 2), 0.5)
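        # the number of boxes and scores must match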
        with pytest.raises(RuntimeError):
            ops.nms(torch.rand(3, 4), torch.rand(4), 0.5)

    @cpu_only
    @pytest.mark.parametrize("iou", (.2, .5, .8))
    @pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 50), (3, 10)))
    def test_qnms(self, iou, scale, zero_point):
        # Note: we compare qnms vs nms instead of qnms vs reference implementation.
        # This is because with the int conversion, the trick used in _create_tensors_with_iou
        # doesn't really work (in fact, nms vs the reference implementation will also fail with ints)
        err_msg = 'NMS and QNMS give different results for IoU={}'
        boxes, scores = self._create_tensors_with_iou(1000, iou)
        scores *= 100  # otherwise most scores would be 0 or 1 after int conversion

        qboxes = torch.quantize_per_tensor(boxes, scale=scale, zero_point=zero_point, dtype=torch.quint8)
        qscores = torch.quantize_per_tensor(scores, scale=scale, zero_point=zero_point, dtype=torch.quint8)

        boxes = qboxes.dequantize()
        scores = qscores.dequantize()

        keep = ops.nms(boxes, scores, iou)
        qkeep = ops.nms(qboxes, qscores, iou)

        assert torch.allclose(qkeep, keep), err_msg.format(iou)

    @needs_cuda
    @pytest.mark.parametrize("iou", (.2, .5, .8))
    def test_nms_cuda(self, iou, dtype=torch.float64):
        tol = 1e-3 if dtype is torch.half else 1e-5
        err_msg = 'NMS incompatible between CPU and CUDA for IoU={}'

        boxes, scores = self._create_tensors_with_iou(1000, iou)
        r_cpu = ops.nms(boxes, scores, iou)
        r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)

        is_eq = torch.allclose(r_cpu, r_cuda.cpu())
        if not is_eq:
            # if the indices are not the same, ensure that it's because the scores
            # are duplicated
            is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol)
        assert is_eq, err_msg.format(iou)

    @needs_cuda
    @pytest.mark.parametrize("iou", (.2, .5, .8))
    @pytest.mark.parametrize("dtype", (torch.float, torch.half))
    def test_autocast(self, iou, dtype):
        with torch.cuda.amp.autocast():
            self.test_nms_cuda(iou=iou, dtype=dtype)

    @needs_cuda
    def test_nms_cuda_float16(self):
        boxes = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551],
                              [285.1472, 188.7374, 1192.4984, 851.0669],
                              [279.2440, 197.9812, 1189.4746, 849.2019]]).cuda()
        scores = torch.tensor([0.6370, 0.7569, 0.3966]).cuda()

        iou_thres = 0.2
        keep32 = ops.nms(boxes, scores, iou_thres)
        keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
        assert_equal(keep32, keep16)

    @cpu_only
    def test_batched_nms_implementations(self):
        """Make sure that both implementations of batched_nms yield identical results"""

        num_boxes = 1000
        iou_threshold = .9

        boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1)
        assert max(boxes[:, 0]) < min(boxes[:, 2])  # x1 < x2
        assert max(boxes[:, 1]) < min(boxes[:, 3])  # y1 < y2

        scores = torch.rand(num_boxes)
        idxs = torch.randint(0, 4, size=(num_boxes,))
        keep_vanilla = ops.boxes._batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
        keep_trick = ops.boxes._batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)

        torch.testing.assert_close(
            keep_vanilla, keep_trick, msg="The vanilla and the trick implementation yield different nms outputs."
        )

        # Also make sure an empty tensor is returned if boxes is empty
        empty = torch.empty((0,), dtype=torch.int64)
        torch.testing.assert_close(empty, ops.batched_nms(empty, None, None, None))


class DeformConvTester(OpTester, unittest.TestCase):
    def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
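        # Reference implementation of (modulated) deformable convolution: each
        # kernel tap (di, dj) samples the input at its regular grid location
        # plus a learned per-position offset, via bilinear interpolation, and
        # the sample is optionally scaled by a modulation mask.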
        stride_h, stride_w = _pair(stride)
        pad_h, pad_w = _pair(padding)
        dil_h, dil_w = _pair(dilation)
        weight_h, weight_w = weight.shape[-2:]

        n_batches, n_in_channels, in_h, in_w = x.shape
        n_out_channels = weight.shape[0]

        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1

        n_offset_grps = offset.shape[1] // (2 * weight_h * weight_w)
        in_c_per_offset_grp = n_in_channels // n_offset_grps

        n_weight_grps = n_in_channels // weight.shape[1]
        in_c_per_weight_grp = weight.shape[1]
        out_c_per_weight_grp = n_out_channels // n_weight_grps

        out = torch.zeros(n_batches, n_out_channels, out_h, out_w, device=x.device, dtype=x.dtype)
        for b in range(n_batches):
            for c_out in range(n_out_channels):
                for i in range(out_h):
                    for j in range(out_w):
                        for di in range(weight_h):
                            for dj in range(weight_w):
                                for c in range(in_c_per_weight_grp):
                                    weight_grp = c_out // out_c_per_weight_grp
                                    c_in = weight_grp * in_c_per_weight_grp + c

                                    offset_grp = c_in // in_c_per_offset_grp
                                    mask_idx = offset_grp * (weight_h * weight_w) + di * weight_w + dj
                                    offset_idx = 2 * mask_idx

                                    pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j]
                                    pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j]

                                    mask_value = 1.0
                                    if mask is not None:
                                        mask_value = mask[b, mask_idx, i, j]

                                    out[b, c_out, i, j] += (mask_value * weight[c_out, c, di, dj] *
                                                            bilinear_interpolate(x[b, c_in, :, :], pi, pj))
        out += bias.view(1, n_out_channels, 1, 1)
        return out

    @lru_cache(maxsize=None)
    def get_fn_args(self, device, contiguous, batch_sz, dtype):
        n_in_channels = 6
        n_out_channels = 2
        n_weight_grps = 2
        n_offset_grps = 3

        stride = (2, 1)
        pad = (1, 0)
        dilation = (2, 1)

        stride_h, stride_w = stride
        pad_h, pad_w = pad
        dil_h, dil_w = dilation
        weight_h, weight_w = (3, 2)
        in_h, in_w = (5, 4)

        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1

        x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=dtype, requires_grad=True)

        offset = torch.randn(batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w,
                             device=device, dtype=dtype, requires_grad=True)

        mask = torch.randn(batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w,
                           device=device, dtype=dtype, requires_grad=True)

        weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w,
                             device=device, dtype=dtype, requires_grad=True)

        bias = torch.randn(n_out_channels, device=device, dtype=dtype, requires_grad=True)

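        # permute + contiguous + permute back keeps the same values but makes
        # the tensors non-contiguous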
        if not contiguous:
            x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
            offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)

        return x, weight, offset, mask, bias, stride, pad, dilation

    def _test_forward(self, device, contiguous, dtype=None):
        dtype = self.dtype if dtype is None else dtype
        for batch_sz in [0, 33]:
            self._test_forward_with_batchsize(device, contiguous, batch_sz, dtype)

    def _test_forward_with_batchsize(self, device, contiguous, batch_sz, dtype):
        x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
        in_channels = 6
        out_channels = 2
        kernel_size = (3, 2)
        groups = 2
        tol = 2e-3 if dtype is torch.half else 1e-5

        layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
                                 dilation=dilation, groups=groups).to(device=x.device, dtype=dtype)
        res = layer(x, offset, mask)

        weight = layer.weight.data
        bias = layer.bias.data
        expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation)

        torch.testing.assert_close(
            res.to(expected), expected, rtol=tol, atol=tol, msg='\nres:\n{}\nexpected:\n{}'.format(res, expected)
        )

        # no modulation test
        res = layer(x, offset)
        expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation)

        torch.testing.assert_close(
            res.to(expected), expected, rtol=tol, atol=tol, msg='\nres:\n{}\nexpected:\n{}'.format(res, expected)
        )

        # test for wrong sizes
        with self.assertRaises(RuntimeError):
            wrong_offset = torch.rand_like(offset[:, :2])
            res = layer(x, wrong_offset)

        with self.assertRaises(RuntimeError):
            wrong_mask = torch.rand_like(mask[:, :2])
            res = layer(x, offset, wrong_mask)

    def _test_backward(self, device, contiguous):
        for batch_sz in [0, 33]:
            self._test_backward_with_batchsize(device, contiguous, batch_sz)

    def _test_backward_with_batchsize(self, device, contiguous, batch_sz):
        x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(device, contiguous,
                                                                                    batch_sz, self.dtype)

        def func(x_, offset_, mask_, weight_, bias_):
            return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride,
                                     padding=padding, dilation=dilation, mask=mask_)

        gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5, fast_mode=True)

        def func_no_mask(x_, offset_, weight_, bias_):
            return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride,
                                     padding=padding, dilation=dilation, mask=None)

        gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5, fast_mode=True)

        @torch.jit.script
        def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_):
            # type:(Tensor, Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
            return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_,
                                     padding=pad_, dilation=dilation_, mask=mask_)

        gradcheck(lambda z, off, msk, wei, bi: script_func(z, off, msk, wei, bi, stride, padding, dilation),
                  (x, offset, mask, weight, bias), nondet_tol=1e-5, fast_mode=True)

        @torch.jit.script
        def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
            # type:(Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
            return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_,
                                     padding=pad_, dilation=dilation_, mask=None)

        gradcheck(lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation),
                  (x, offset, weight, bias), nondet_tol=1e-5, fast_mode=True)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
    def test_compare_cpu_cuda_grads(self):
        # Test from https://github.com/pytorch/vision/issues/2598
        # Run on CUDA only
        for contiguous in [False, True]:
            # compare grads computed on CUDA with grads computed on CPU
            true_cpu_grads = None

            init_weight = torch.randn(9, 9, 3, 3, requires_grad=True)
            img = torch.randn(8, 9, 1000, 110)
            offset = torch.rand(8, 2 * 3 * 3, 1000, 110)
            mask = torch.rand(8, 3 * 3, 1000, 110)

            if not contiguous:
                img = img.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
                offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
                mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
                weight = init_weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
            else:
                weight = init_weight

            for d in ["cpu", "cuda"]:

                out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d))
                out.mean().backward()
                if true_cpu_grads is None:
                    self.assertTrue(init_weight.grad is not None)
                    # clone the CPU grads and reset .grad, otherwise the CUDA
                    # backward pass would accumulate into the same tensor and
                    # the comparison below would be vacuous
                    true_cpu_grads = init_weight.grad.clone()
                    init_weight.grad = None
                else:
                    self.assertTrue(init_weight.grad is not None)
                    res_grads = init_weight.grad.to("cpu")
                    torch.testing.assert_close(true_cpu_grads, res_grads)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
    def test_autocast(self):
        for dtype in (torch.float, torch.half):
            with torch.cuda.amp.autocast():
                self._test_forward(torch.device("cuda"), False, dtype=dtype)


class FrozenBNTester(unittest.TestCase):
    def test_frozenbatchnorm2d_repr(self):
        num_features = 32
        eps = 1e-5
        t = ops.misc.FrozenBatchNorm2d(num_features, eps=eps)

        # Check integrity of object __repr__ attribute
        expected_string = f"FrozenBatchNorm2d({num_features}, eps={eps})"
        self.assertEqual(t.__repr__(), expected_string)

    def test_frozenbatchnorm2d_eps(self):
        sample_size = (4, 32, 28, 28)
        x = torch.rand(sample_size)
        state_dict = dict(weight=torch.rand(sample_size[1]),
                          bias=torch.rand(sample_size[1]),
                          running_mean=torch.rand(sample_size[1]),
                          running_var=torch.rand(sample_size[1]),
                          num_batches_tracked=torch.tensor(100))

        # Check that default eps is equal to the one of BN
        fbn = ops.misc.FrozenBatchNorm2d(sample_size[1])
        fbn.load_state_dict(state_dict, strict=False)
        bn = torch.nn.BatchNorm2d(sample_size[1]).eval()
        bn.load_state_dict(state_dict)
        # Difference is expected to fall in an acceptable range
        torch.testing.assert_close(fbn(x), bn(x), rtol=1e-5, atol=1e-6)

        # Check computation for eps > 0
        fbn = ops.misc.FrozenBatchNorm2d(sample_size[1], eps=1e-5)
        fbn.load_state_dict(state_dict, strict=False)
        bn = torch.nn.BatchNorm2d(sample_size[1], eps=1e-5).eval()
        bn.load_state_dict(state_dict)
        torch.testing.assert_close(fbn(x), bn(x), rtol=1e-5, atol=1e-6)

    def test_frozenbatchnorm2d_n_arg(self):
        """Ensure a warning is thrown when passing `n` kwarg
        (remove this when support of `n` is dropped)"""
        self.assertWarns(DeprecationWarning, ops.misc.FrozenBatchNorm2d, 32, eps=1e-5, n=32)


class BoxConversionTester(unittest.TestCase):
    @staticmethod
    def _get_box_sequences():
        # Define here the argument type of `boxes` supported by region pooling operations
        box_tensor = torch.tensor([[0, 0, 0, 100, 100], [1, 0, 0, 100, 100]], dtype=torch.float)
        box_list = [torch.tensor([[0, 0, 100, 100]], dtype=torch.float),
                    torch.tensor([[0, 0, 100, 100]], dtype=torch.float)]
        box_tuple = tuple(box_list)
        return box_tensor, box_list, box_tuple

    def test_check_roi_boxes_shape(self):
        # Ensure common sequences of tensors are supported
        for box_sequence in self._get_box_sequences():
            self.assertIsNone(ops._utils.check_roi_boxes_shape(box_sequence))

    def test_convert_boxes_to_roi_format(self):
        # Ensure common sequences of tensors yield the same result
        ref_tensor = None
        for box_sequence in self._get_box_sequences():
            if ref_tensor is None:
                ref_tensor = box_sequence
            else:
                self.assertTrue(torch.equal(ref_tensor, ops._utils.convert_boxes_to_roi_format(box_sequence)))


class BoxTester(unittest.TestCase):
    def test_bbox_same(self):
        box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
                                  [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)

        exp_xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
                                [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)

        assert exp_xyxy.size() == torch.Size([4, 4])
        assert_equal(ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xyxy"), exp_xyxy)
        assert_equal(ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="xywh"), exp_xyxy)
        assert_equal(ops.box_convert(box_tensor, in_fmt="cxcywh", out_fmt="cxcywh"), exp_xyxy)

    def test_bbox_xyxy_xywh(self):
        # Simple test: convert boxes to xywh and back. Make sure they are the same.
        # box_tensor is in x1 y1 x2 y2 format.
        box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
                                  [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
        exp_xywh = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
                                [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float)

        assert exp_xywh.size() == torch.Size([4, 4])
        box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
        assert_equal(box_xywh, exp_xywh)

        # Reverse conversion
        box_xyxy = ops.box_convert(box_xywh, in_fmt="xywh", out_fmt="xyxy")
        assert_equal(box_xyxy, box_tensor)

    def test_bbox_xyxy_cxcywh(self):
        # Simple test: convert boxes to cxcywh and back. Make sure they are the same.
        # box_tensor is in x1 y1 x2 y2 format.
        box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
                                  [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
        exp_cxcywh = torch.tensor([[50, 50, 100, 100], [0, 0, 0, 0],
                                  [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float)

        assert exp_cxcywh.size() == torch.Size([4, 4])
        box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
        assert_equal(box_cxcywh, exp_cxcywh)

        # Reverse conversion
        box_xyxy = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xyxy")
        assert_equal(box_xyxy, box_tensor)

    def test_bbox_xywh_cxcywh(self):
        box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
                                  [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float)

        exp_cxcywh = torch.tensor([[50, 50, 100, 100], [0, 0, 0, 0],
                                  [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float)

        assert exp_cxcywh.size() == torch.Size([4, 4])
        box_cxcywh = ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="cxcywh")
        assert_equal(box_cxcywh, exp_cxcywh)

        # Reverse conversion
        box_xywh = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xywh")
        assert_equal(box_xywh, box_tensor)

    def test_bbox_invalid(self):
        box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
                                  [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float)

        invalid_infmts = ["xwyh", "cxwyh"]
        invalid_outfmts = ["xwcx", "xhwcy"]
        for inv_infmt in invalid_infmts:
            for inv_outfmt in invalid_outfmts:
                self.assertRaises(ValueError, ops.box_convert, box_tensor, inv_infmt, inv_outfmt)

    def test_bbox_convert_jit(self):
        box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
                                  [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)

        scripted_fn = torch.jit.script(ops.box_convert)
        TOLERANCE = 1e-3

        box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
        scripted_xywh = scripted_fn(box_tensor, 'xyxy', 'xywh')
        torch.testing.assert_close(scripted_xywh, box_xywh, rtol=0.0, atol=TOLERANCE)

        box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
        scripted_cxcywh = scripted_fn(box_tensor, 'xyxy', 'cxcywh')
        torch.testing.assert_close(scripted_cxcywh, box_cxcywh, rtol=0.0, atol=TOLERANCE)


class BoxAreaTester(unittest.TestCase):
    def test_box_area(self):
        def area_check(box, expected, tolerance=1e-4):
            out = ops.box_area(box)
            torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)

        # Check for int boxes
        for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
            box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype)
            expected = torch.tensor([10000, 0])
            area_check(box_tensor, expected)

        # Check for float32 and float64 boxes
        for dtype in [torch.float32, torch.float64]:
            box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551],
                                       [285.1472, 188.7374, 1192.4984, 851.0669],
                                       [279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype)
            expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=torch.float64)
            area_check(box_tensor, expected, tolerance=0.05)

        # Check for float16 box
        box_tensor = torch.tensor([[285.25, 185.625, 1194.0, 851.5],
                                   [285.25, 188.75, 1192.0, 851.0],
                                   [279.25, 198.0, 1189.0, 849.0]], dtype=torch.float16)
        expected = torch.tensor([605113.875, 600495.1875, 592247.25])
        area_check(box_tensor, expected)


class BoxIouTester(unittest.TestCase):
    def test_iou(self):
        def iou_check(box, expected, tolerance=1e-4):
            out = ops.box_iou(box, box)
            torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)

        # Check for int boxes
        for dtype in [torch.int16, torch.int32, torch.int64]:
            box = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=dtype)
            expected = torch.tensor([[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]])
            iou_check(box, expected)

        # Check for float boxes
        for dtype in [torch.float16, torch.float32, torch.float64]:
            box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551],
                                       [285.1472, 188.7374, 1192.4984, 851.0669],
                                       [279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype)
            expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]])
            iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-4)


class GenBoxIouTester(unittest.TestCase):
    def test_gen_iou(self):
        def gen_iou_check(box, expected, tolerance=1e-4):
            out = ops.generalized_box_iou(box, box)
            torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)

        # Check for int boxes
        for dtype in [torch.int16, torch.int32, torch.int64]:
            box = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=dtype)
            expected = torch.tensor([[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]])
            gen_iou_check(box, expected)

        # Check for float boxes
        for dtype in [torch.float16, torch.float32, torch.float64]:
            box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551],
                                       [285.1472, 188.7374, 1192.4984, 851.0669],
                                       [279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype)
            expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]])
            gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3)


if __name__ == '__main__':
    unittest.main()