# Copyright (c) OpenMMLab. All rights reserved.
import os

import numpy as np
5
import pytest
6
7
import torch

8
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
9

10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# Use the parrots gradcheck when that package can be imported; otherwise fall
# back to the torch one. ``_USING_PARROTS`` records which of the two was
# picked, because the tests pass different keyword arguments to each.
try:
    from parrots.autograd import gradcheck
except ImportError:
    from torch.autograd import gradcheck
    _USING_PARROTS = False
else:
    _USING_PARROTS = True

# Absolute path of the directory that holds this module.
cur_dir = os.path.split(os.path.abspath(__file__))[0]

# Fixtures are kept as plain nested lists; the tests turn them into arrays
# with ``np.array``. ``inputs`` and ``outputs`` are index-aligned: the tests
# iterate over them together with ``zip``.
#
# Each ``inputs`` entry is ``(feature_map, rois)``. Every RoI row holds five
# numbers (presumably [batch_index, x1, y1, x2, y2] -- TODO confirm against
# the op's documentation).
inputs = [
    # 1x1x2x2 feature map
    (
        [[[[1., 2.], [3., 4.]]]],
        [[0., 0., 0., 1., 1.]],
    ),
    # 1x2x2x2 feature map
    (
        [[[[1., 2.], [3., 4.]], [[4., 3.], [2., 1.]]]],
        [[0., 0., 0., 1., 1.]],
    ),
    # 1x1x4x4 feature map
    (
        [[[[1., 2., 5., 6.],
           [3., 4., 7., 8.],
           [9., 10., 13., 14.],
           [11., 12., 15., 16.]]]],
        [[0., 0., 0., 3., 3.]],
    ),
]

# Each ``outputs`` entry is ``(expected_output, expected_input_grad)`` for the
# ``inputs`` entry at the same index.
outputs = [
    (
        [[[[1, 1.25], [1.5, 1.75]]]],
        [[[[3.0625, 0.4375], [0.4375, 0.0625]]]],
    ),
    (
        [[[[1., 1.25], [1.5, 1.75]], [[4, 3.75], [3.5, 3.25]]]],
        [[[[3.0625, 0.4375], [0.4375, 0.0625]],
          [[3.0625, 0.4375], [0.4375, 0.0625]]]],
    ),
    (
        [[[[1.9375, 4.75], [7.5625, 10.375]]]],
        [[[[0.47265625, 0.4296875, 0.4296875, 0.04296875],
           [0.4296875, 0.390625, 0.390625, 0.0390625],
           [0.4296875, 0.390625, 0.390625, 0.0390625],
           [0.04296875, 0.0390625, 0.0390625, 0.00390625]]]],
    ),
]


41
class TestDeformRoIPool:
    """Tests for ``DeformRoIPoolPack`` and ``ModulatedDeformRoIPoolPack``.

    All tests use a 2x2 pooled output, ``spatial_scale=1.0`` and
    ``sampling_ratio=2``, and iterate over the module-level ``inputs``
    (and, for the value checks, ``outputs``) fixtures.
    """

    def _check_gradients(self, pool_cls):
        """Run ``gradcheck`` on ``pool_cls`` for every case in ``inputs``.

        Args:
            pool_cls: Module class to instantiate. It is called as
                ``pool_cls((pool_h, pool_w), output_c, spatial_scale=...,
                sampling_ratio=...)`` and moved to CUDA.
        """
        pool_h = 2
        pool_w = 2
        spatial_scale = 1.0
        sampling_ratio = 2

        for case in inputs:
            np_input = np.array(case[0])
            np_rois = np.array(case[1])

            x = torch.tensor(
                np_input, device='cuda', dtype=torch.float, requires_grad=True)
            rois = torch.tensor(np_rois, device='cuda', dtype=torch.float)
            # The module is built with as many channels as the input has
            # along dimension 1.
            output_c = x.size(1)

            droipool = pool_cls((pool_h, pool_w),
                                output_c,
                                spatial_scale=spatial_scale,
                                sampling_ratio=sampling_ratio).cuda()

            # The two gradcheck implementations take different keyword
            # arguments, so the call depends on which one was imported.
            if _USING_PARROTS:
                gradcheck(droipool, (x, rois), no_grads=[rois])
            else:
                gradcheck(droipool, (x, rois), eps=1e-2, atol=1e-2)

    def test_deform_roi_pool_gradcheck(self):
        """Gradient check of ``DeformRoIPoolPack`` on CUDA."""
        if not torch.cuda.is_available():
            # Report the test as skipped instead of letting it pass silently.
            pytest.skip('requires CUDA support')
        # Imported only after the CUDA guard, as in the other tests here.
        from mmcv.ops import DeformRoIPoolPack
        self._check_gradients(DeformRoIPoolPack)

    def test_modulated_deform_roi_pool_gradcheck(self):
        """Gradient check of ``ModulatedDeformRoIPoolPack`` on CUDA."""
        if not torch.cuda.is_available():
            # Report the test as skipped instead of letting it pass silently.
            pytest.skip('requires CUDA support')
        from mmcv.ops import ModulatedDeformRoIPoolPack
        self._check_gradients(ModulatedDeformRoIPoolPack)

    def _test_deform_roi_pool_allclose(self, device, dtype=torch.float):
        """Compare forward output and input gradient with ``outputs``.

        Args:
            device: Device string the tensors and the module are placed on.
            dtype: Requested floating point type.
                NOTE(review): this value is currently not applied -- the
                tensors below are always created as ``torch.float``, so every
                ``dtype`` parametrization exercises the same path. Honouring
                it would presumably also require casting the module and
                revisiting the tolerances; confirm on real hardware before
                changing.
        """
        from mmcv.ops import DeformRoIPoolPack
        pool_h = 2
        pool_w = 2
        spatial_scale = 1.0
        sampling_ratio = 2

        for case, expected in zip(inputs, outputs):
            np_input = np.array(case[0])
            np_rois = np.array(case[1])
            np_output = np.array(expected[0])
            np_grad = np.array(expected[1])

            x = torch.tensor(
                np_input, device=device, dtype=torch.float, requires_grad=True)
            rois = torch.tensor(np_rois, device=device, dtype=torch.float)
            output_c = x.size(1)
            droipool = DeformRoIPoolPack(
                (pool_h, pool_w),
                output_c,
                spatial_scale=spatial_scale,
                sampling_ratio=sampling_ratio).to(device)

            pooled = droipool(x, rois)
            # Back-propagate a gradient of ones so that ``x.grad`` can be
            # compared with the expected gradient fixture.
            pooled.backward(torch.ones_like(pooled))
            assert np.allclose(
                pooled.data.cpu().numpy(), np_output, rtol=1e-3)
            assert np.allclose(x.grad.data.cpu().numpy(), np_grad, rtol=1e-3)

    @pytest.mark.parametrize('device', [
        pytest.param(
            'npu',
            marks=pytest.mark.skipif(
                not IS_NPU_AVAILABLE, reason='requires NPU support')),
        pytest.param(
            'cuda',
            marks=pytest.mark.skipif(
                not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
        pytest.param(
            'mlu',
            marks=pytest.mark.skipif(
                not IS_MLU_AVAILABLE, reason='requires MLU support'))
    ])
    @pytest.mark.parametrize('dtype', [
        torch.float,
        pytest.param(
            torch.double,
            marks=pytest.mark.skipif(
                IS_MLU_AVAILABLE,
                reason='MLU does not support for 64-bit floating point')),
        torch.half
    ])
    def test_deform_roi_pool_allclose(self, device, dtype):
        """Value check of ``DeformRoIPoolPack`` per device and dtype."""
        self._test_deform_roi_pool_allclose(device, dtype)