from __future__ import division
import numpy as np
import torch
from torch.autograd import gradcheck

from torchvision import ops

from itertools import product
import unittest


class RoIOpTester(object):
    @classmethod
    def setUpClass(cls):
        cls.dtype = torch.float64

    def test_forward_cpu_contiguous(self):
        self._test_forward(device=torch.device('cpu'), contiguous=True)

    def test_forward_cpu_non_contiguous(self):
        self._test_forward(device=torch.device('cpu'), contiguous=False)

    def test_backward_cpu_contiguous(self):
        self._test_backward(device=torch.device('cpu'), contiguous=True)

    def test_backward_cpu_non_contiguous(self):
        self._test_backward(device=torch.device('cpu'), contiguous=False)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
    def test_forward_cuda_contiguous(self):
        self._test_forward(device=torch.device('cuda'), contiguous=True)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
    def test_forward_cuda_non_contiguous(self):
        self._test_forward(device=torch.device('cuda'), contiguous=False)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
    def test_backward_cuda_contiguous(self):
        self._test_backward(device=torch.device('cuda'), contiguous=True)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
    def test_backward_cuda_non_contiguous(self):
        self._test_backward(device=torch.device('cuda'), contiguous=False)

    def _test_forward(self, device, contiguous):
        pool_size = 5
        # n_channels % (pool_size ** 2) == 0 is required for the position-sensitive (PS) operations.
        n_channels = 2 * (pool_size ** 2)
        x = torch.rand(2, n_channels, 10, 10, dtype=self.dtype, device=device)
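        # A permuted view is non-contiguous; with a square input the shape stays (10, 10).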
        if not contiguous:
            x = x.permute(0, 1, 3, 2)
        rois = torch.tensor([[0, 0, 0, 9, 9],  # format: (batch_index, x1, y1, x2, y2)
                             [0, 0, 5, 4, 9],
                             [0, 5, 5, 9, 9],
                             [1, 0, 0, 9, 9]],
                            dtype=self.dtype, device=device)

        pool_h, pool_w = pool_size, pool_size
        y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1)
        gt_y = self.expected_fn(x, rois, pool_h, pool_w, spatial_scale=1,
                                sampling_ratio=-1, device=device, dtype=self.dtype)

        self.assertTrue(torch.allclose(gt_y, y))

    def _test_backward(self, device, contiguous):
        pool_size = 2
        x = torch.rand(1, 2 * (pool_size ** 2), 5, 5, dtype=self.dtype, device=device, requires_grad=True)
        if not contiguous:
            x = x.permute(0, 1, 3, 2)
        rois = torch.tensor([[0, 0, 0, 4, 4],  # format: (batch_index, x1, y1, x2, y2)
                             [0, 0, 2, 3, 4],
                             [0, 2, 2, 4, 4]],
                            dtype=self.dtype, device=device)

        def func(z):
            return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)

        script_func = self.get_script_fn(rois, pool_size)

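        # Check analytical gradients for both the eager op and its scripted wrapper.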
        self.assertTrue(gradcheck(func, (x,)))
        self.assertTrue(gradcheck(script_func, (x,)))

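    # The hooks below are overridden by the concrete testers
    # (RoIPool, PSRoIPool, RoIAlign, PSRoIAlign).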
    def fn(*args, **kwargs):
        pass

    def get_script_fn(*args, **kwargs):
        pass

    def expected_fn(*args, **kwargs):
        pass


class RoIPoolTester(RoIOpTester, unittest.TestCase):
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        return ops.RoIPool((pool_h, pool_w), spatial_scale)(x, rois)

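    # Scripted wrapper around the functional op, used by the shared backward test
    # to also exercise the TorchScript path.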
    def get_script_fn(self, rois, pool_size):
        @torch.jit.script
        def script_fn(input, rois, pool_size):
            # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
            return ops.roi_pool(input, rois, pool_size, 1.0)[0]
        return lambda x: script_fn(x, rois, pool_size)

    def expected_fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1,
                    device=None, dtype=torch.float64):
        if device is None:
            device = torch.device("cpu")

        n_channels = x.size(1)
        y = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device)

        def get_slice(k, block):
            return slice(int(np.floor(k * block)), int(np.ceil((k + 1) * block)))

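        # Naive reference: crop each RoI, split it into pool_h x pool_w bins and
        # max-pool every bin over all channels.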
        for roi_idx, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (int(round(x.item() * spatial_scale)) for x in roi[1:])
            roi_x = x[batch_idx, :, i_begin:i_end + 1, j_begin:j_end + 1]

            roi_h, roi_w = roi_x.shape[-2:]
            bin_h = roi_h / pool_h
            bin_w = roi_w / pool_w

            for i in range(0, pool_h):
                for j in range(0, pool_w):
                    bin_x = roi_x[:, get_slice(i, bin_h), get_slice(j, bin_w)]
                    if bin_x.numel() > 0:
                        y[roi_idx, :, i, j] = bin_x.reshape(n_channels, -1).max(dim=1)[0]
        return y


class PSRoIPoolTester(RoIOpTester, unittest.TestCase):
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        return ops.PSRoIPool((pool_h, pool_w), spatial_scale)(x, rois)

    def get_script_fn(self, rois, pool_size):
        @torch.jit.script
        def script_fn(input, rois, pool_size):
            # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
            return ops.ps_roi_pool(input, rois, pool_size, 1.0)[0]
        return lambda x: script_fn(x, rois, pool_size)

    def expected_fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1,
                    device=None, dtype=torch.float64):
        if device is None:
            device = torch.device("cpu")
        n_input_channels = x.size(1)
        self.assertEqual(n_input_channels % (pool_h * pool_w), 0, "input channels must be divisible by ph * pw")
        n_output_channels = int(n_input_channels / (pool_h * pool_w))
        y = torch.zeros(rois.size(0), n_output_channels, pool_h, pool_w, dtype=dtype, device=device)

        def get_slice(k, block):
            return slice(int(np.floor(k * block)), int(np.ceil((k + 1) * block)))

        for roi_idx, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (int(round(x.item() * spatial_scale)) for x in roi[1:])
            roi_x = x[batch_idx, :, i_begin:i_end + 1, j_begin:j_end + 1]

            roi_height = max(i_end - i_begin, 1)
            roi_width = max(j_end - j_begin, 1)
            bin_h, bin_w = roi_height / float(pool_h), roi_width / float(pool_w)

            for i in range(0, pool_h):
                for j in range(0, pool_w):
                    bin_x = roi_x[:, get_slice(i, bin_h), get_slice(j, bin_w)]
                    if bin_x.numel() > 0:
                        area = bin_x.size(-2) * bin_x.size(-1)
                        for c_out in range(0, n_output_channels):
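                            # Position-sensitive mapping: each output channel/bin
                            # reads exactly one dedicated input channel.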
                            c_in = c_out * (pool_h * pool_w) + pool_w * i + j
                            t = torch.sum(bin_x[c_in, :, :])
                            y[roi_idx, c_out, i, j] = t / area
        return y


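# Reference bilinear interpolation on a flattened, row-major 2D map; coordinates
# are clamped to the valid range and out-of-bounds samples contribute zero.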
def bilinear_interpolate(data, height, width, y, x):
    if y < -1.0 or y > height or x < -1.0 or x > width:
        return 0.

    y = min(max(0, y), height - 1)
    x = min(max(0, x), width - 1)

    y_low = int(y)
    y_high = min(y_low + 1, height - 1)

    x_low = int(x)
    x_high = min(x_low + 1, width - 1)

    wy_h = y - y_low
    wy_l = 1 - wy_h

    wx_h = x - x_low
    wx_l = 1 - wx_h

    val = 0
    for wx, x in zip((wx_l, wx_h), (x_low, x_high)):
        for wy, y in zip((wy_l, wy_h), (y_low, y_high)):
            val += wx * wy * data[y * width + x]
    return val


class RoIAlignTester(RoIOpTester, unittest.TestCase):
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        return ops.RoIAlign((pool_h, pool_w), spatial_scale=spatial_scale,
                            sampling_ratio=sampling_ratio)(x, rois)

    def get_script_fn(self, rois, pool_size):
        @torch.jit.script
        def script_fn(input, rois, pool_size):
            # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
            return ops.roi_align(input, rois, pool_size, 1.0)[0]
        return lambda x: script_fn(x, rois, pool_size)

    def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1,
                    device=None, dtype=torch.float64):
        if device is None:
            device = torch.device("cpu")
        n_channels = in_data.size(1)
        out_data = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device)

        for r, roi in enumerate(rois):
            batch_idx = int(roi[0])
            j_begin, i_begin, j_end, i_end = (x.item() * spatial_scale for x in roi[1:])

            roi_h = i_end - i_begin
            roi_w = j_end - j_begin
            bin_h = roi_h / pool_h
            bin_w = roi_w / pool_w

            for i in range(0, pool_h):
                start_h = i_begin + i * bin_h
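                # sampling_ratio > 0 fixes the sampling grid size; otherwise the
                # grid adapts to the bin size (ceil of the bin extent).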
                grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h))
                for j in range(0, pool_w):
                    start_w = j_begin + j * bin_w
                    grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))

                    for channel in range(0, n_channels):

                        val = 0
                        for iy in range(0, grid_h):
                            y = start_h + (iy + 0.5) * bin_h / grid_h
                            for ix in range(0, grid_w):
                                x = start_w + (ix + 0.5) * bin_w / grid_w
                                val += bilinear_interpolate(
                                    in_data[batch_idx, channel, :, :].flatten(),
                                    in_data.size(-2),
                                    in_data.size(-1),
                                    y, x
                                )
                        val /= grid_h * grid_w

                        out_data[r, channel, i, j] = val
        return out_data


class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale,
                              sampling_ratio=sampling_ratio)(x, rois)

    def get_script_fn(self, rois, pool_size):
        @torch.jit.script
        def script_fn(input, rois, pool_size):
            # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
            return ops.ps_roi_align(input, rois, pool_size, 1.0)[0]
        return lambda x: script_fn(x, rois, pool_size)

    def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1,
                    sampling_ratio=-1, device=None, dtype=torch.float64):
        if device is None:
            device = torch.device("cpu")
        n_input_channels = in_data.size(1)
        self.assertEqual(n_input_channels % (pool_h * pool_w), 0, "input channels must be divisible by ph * pw")
        n_output_channels = int(n_input_channels / (pool_h * pool_w))
        out_data = torch.zeros(rois.size(0), n_output_channels, pool_h, pool_w, dtype=dtype, device=device)

        for r, roi in enumerate(rois):
            batch_idx = int(roi[0])
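            # The scaled box is shifted by -0.5, matching the continuous-coordinate
            # convention used by ps_roi_align.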
            j_begin, i_begin, j_end, i_end = (x.item() * spatial_scale - 0.5 for x in roi[1:])

            roi_h = i_end - i_begin
            roi_w = j_end - j_begin
            bin_h = roi_h / pool_h
            bin_w = roi_w / pool_w

            for i in range(0, pool_h):
                start_h = i_begin + i * bin_h
                grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h))
                for j in range(0, pool_w):
                    start_w = j_begin + j * bin_w
                    grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))
                    for c_out in range(0, n_output_channels):
                        c_in = c_out * (pool_h * pool_w) + pool_w * i + j

                        val = 0
                        for iy in range(0, grid_h):
                            y = start_h + (iy + 0.5) * bin_h / grid_h
                            for ix in range(0, grid_w):
                                x = start_w + (ix + 0.5) * bin_w / grid_w
                                val += bilinear_interpolate(
                                    in_data[batch_idx, c_in, :, :].flatten(),
                                    in_data.size(-2),
                                    in_data.size(-1),
                                    y, x
                                )
                        val /= grid_h * grid_w

                        out_data[r, c_out, i, j] = val
        return out_data


class NMSTester(unittest.TestCase):
    def reference_nms(self, boxes, scores, iou_threshold):
        """
        Args:
            box_scores (N, 5): boxes in corner-form and probabilities.
            iou_threshold: intersection over union threshold.
        Returns:
             picked: a list of indexes of the kept boxes
        """
        picked = []
        _, indexes = scores.sort(descending=True)
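        # Greedy NMS: repeatedly keep the highest-scoring remaining box and drop
        # every other box whose IoU with it exceeds the threshold.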
        while len(indexes) > 0:
            current = indexes[0]
            picked.append(current.item())
            if len(indexes) == 1:
                break
            current_box = boxes[current, :]
            indexes = indexes[1:]
            rest_boxes = boxes[indexes, :]
            iou = ops.box_iou(rest_boxes, current_box.unsqueeze(0)).squeeze(1)
            indexes = indexes[iou <= iou_threshold]

        return torch.as_tensor(picked)

    def _create_tensors_with_iou(self, N, iou_thresh):
        # force last box to have a pre-defined iou with the first box
        # let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1],
        # then, in order to satisfy ops.iou(b0, b1) == iou_thresh,
        # we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh
        boxes = torch.rand(N, 4) * 100
        boxes[:, 2:] += boxes[:, :2]
        boxes[-1, :] = boxes[0, :]
        x0, y0, x1, y1 = boxes[-1].tolist()
        boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh
        scores = torch.rand(N)
        return boxes, scores

    def test_nms(self):
        err_msg = 'NMS incompatible between CPU and reference implementation for IoU={}'
        for iou in [0.2, 0.5, 0.8]:
            boxes, scores = self._create_tensors_with_iou(1000, iou)
            keep_ref = self.reference_nms(boxes, scores, iou)
            keep = ops.nms(boxes, scores, iou)
            self.assertTrue(torch.allclose(keep, keep_ref), err_msg.format(iou))

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
    def test_nms_cuda(self):
        err_msg = 'NMS incompatible between CPU and CUDA for IoU={}'

        for iou in [0.2, 0.5, 0.8]:
            boxes, scores = self._create_tensors_with_iou(1000, iou)
            r_cpu = ops.nms(boxes, scores, iou)
            r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)

            self.assertTrue(torch.allclose(r_cpu, r_cuda.cpu()), err_msg.format(iou))


class NewEmptyTensorTester(unittest.TestCase):
    def test_new_empty_tensor(self):
        input = torch.tensor([2., 2.], requires_grad=True)
        new_shape = [3, 3]
        out = torch.ops.torchvision._new_empty_tensor_op(input, new_shape)
        assert out.size() == torch.Size([3, 3])
        assert out.dtype == input.dtype


if __name__ == '__main__':
    unittest.main()