# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch

6
7
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE

8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Use the parrots gradcheck when that backend can be imported; otherwise fall
# back to the PyTorch one. _USING_PARROTS records which import succeeded.
try:
    from parrots.autograd import gradcheck
except ImportError:
    from torch.autograd import gradcheck
    _USING_PARROTS = False
else:
    _USING_PARROTS = True
# yapf:disable
# Test fixtures. Each entry of `inputs` is a (feature_map, rois) pair given as
# nested lists; the tests below turn case[0] into the input tensor and case[1]
# into the RoI tensor.
# NOTE(review): every RoI row holds six numbers, presumably
# (batch_index, center_x, center_y, width, height, angle_in_radians) - confirm
# against the RoIAlignRotated documentation.
inputs = [([[[[1., 2.], [3., 4.]]]],
           [[0., 0.5, 0.5, 1., 1., 0]]),
          ([[[[1., 2.], [3., 4.]]]],
           [[0., 0.5, 0.5, 1., 1., np.pi / 2]]),
          ([[[[1., 2.], [3., 4.]],
             [[4., 3.], [2., 1.]]]],
           [[0., 0.5, 0.5, 1., 1., 0]]),
          ([[[[1., 2., 5., 6.], [3., 4., 7., 8.],
              [9., 10., 13., 14.], [11., 12., 15., 16.]]]],
           [[0., 1.5, 1.5, 3., 3., 0]]),
          ([[[[1., 2., 5., 6.], [3., 4., 7., 8.],
              [9., 10., 13., 14.], [11., 12., 15., 16.]]]],
           [[0., 1.5, 1.5, 3., 3., np.pi / 2]])]
# `outputs` is paired index-by-index with `inputs`. Each entry is
# (expected_forward_output, expected_input_gradient); the gradient is the
# value compared against x.grad after backpropagating an all-ones tensor.
outputs = [([[[[1.0, 1.25], [1.5, 1.75]]]],
            [[[[3.0625, 0.4375], [0.4375, 0.0625]]]]),
           ([[[[1.5, 1], [1.75, 1.25]]]],
            [[[[3.0625, 0.4375], [0.4375, 0.0625]]]]),
           ([[[[1.0, 1.25], [1.5, 1.75]],
              [[4.0, 3.75], [3.5, 3.25]]]],
            [[[[3.0625, 0.4375], [0.4375, 0.0625]],
              [[3.0625, 0.4375], [0.4375, 0.0625]]]]),
           ([[[[1.9375, 4.75], [7.5625, 10.375]]]],
            [[[[0.47265625, 0.42968750, 0.42968750, 0.04296875],
               [0.42968750, 0.39062500, 0.39062500, 0.03906250],
               [0.42968750, 0.39062500, 0.39062500, 0.03906250],
               [0.04296875, 0.03906250, 0.03906250, 0.00390625]]]]),
           ([[[[7.5625, 1.9375], [10.375, 4.75]]]],
            [[[[0.47265625, 0.42968750, 0.42968750, 0.04296875],
               [0.42968750, 0.39062500, 0.39062500, 0.03906250],
               [0.42968750, 0.39062500, 0.39062500, 0.03906250],
               [0.04296875, 0.03906250, 0.03906250, 0.00390625]]]])]
# yapf:enable

# Module-level RoIAlignRotated settings: pooled output size (pool_h x pool_w),
# spatial scale and sampling ratio passed to the op.
pool_h, pool_w = 2, 2
spatial_scale, sampling_ratio = 1.0, 2


def _test_roialign_rotated_gradcheck(device, dtype):
    """Run gradcheck on ``RoIAlignRotated`` for every fixture in ``inputs``.

    The test is skipped when the op cannot be imported from ``mmcv.ops`` or
    when ``dtype`` is ``torch.half`` (the grad check does not support fp16).

    Args:
        device: Device on which the input and RoI tensors are created.
        dtype: Torch dtype used for the input and RoI tensors.
    """
    try:
        from mmcv.ops import RoIAlignRotated
    except ModuleNotFoundError:
        pytest.skip('RoIAlignRotated op is not successfully compiled')
    if dtype is torch.half:
        pytest.skip('grad check does not support fp16')

    # The parrots and PyTorch gradcheck helpers take different keyword
    # arguments, so the call below is chosen per backend.
    # NOTE(review): the module-level _USING_PARROTS flag is not consulted
    # here; the branch keys off torch.__version__ instead.
    on_parrots = torch.__version__ == 'parrots'

    for feature_map, boxes in inputs:
        x = torch.tensor(
            np.array(feature_map),
            dtype=dtype,
            device=device,
            requires_grad=True)
        rois = torch.tensor(np.array(boxes), dtype=dtype, device=device)

        roi_op = RoIAlignRotated((pool_h, pool_w), spatial_scale,
                                 sampling_ratio)
        if not on_parrots:
            gradcheck(roi_op, (x, rois), eps=1e-5, atol=1e-5)
            continue
        # Only x needs a gradient; rois is excluded via no_grads.
        gradcheck(
            roi_op, (x, rois), no_grads=[rois], delta=1e-5, pt_atol=1e-5)


def _test_roialign_rotated_allclose(device, dtype):
    """Check ``roi_align_rotated`` against the precomputed fixtures.

    For every (input, expected) pair in ``inputs`` / ``outputs`` the forward
    result and the gradient of the input (after backpropagating an all-ones
    tensor) are compared with the expected arrays using ``atol=1e-3``.
    Afterwards the ``RoIAlignRotated`` module is built once with the
    deprecated keyword names and once with the current ones, and both must
    produce the same output.

    The test is skipped when the ops cannot be imported from ``mmcv.ops``.
    The pooling configuration comes from the module-level ``pool_h``,
    ``pool_w``, ``spatial_scale`` and ``sampling_ratio``.

    Args:
        device: Device on which the input and RoI tensors are created.
        dtype: Torch dtype used for the input and RoI tensors.
    """
    try:
        from mmcv.ops import RoIAlignRotated, roi_align_rotated
    except ModuleNotFoundError:
        pytest.skip('test requires compilation')

    for case, expected in zip(inputs, outputs):
        np_input = np.array(case[0])
        np_rois = np.array(case[1])
        np_output = np.array(expected[0])
        np_grad = np.array(expected[1])

        x = torch.tensor(
            np_input, dtype=dtype, device=device, requires_grad=True)
        rois = torch.tensor(np_rois, dtype=dtype, device=device)

        # NOTE(review): the trailing positional True is presumably the
        # `aligned` flag of roi_align_rotated - confirm against mmcv.ops.
        output = roi_align_rotated(x, rois, (pool_h, pool_w), spatial_scale,
                                   sampling_ratio, True)
        output.backward(torch.ones_like(output))
        # Results are cast to float32 before the comparison so that the same
        # check works for every tested dtype.
        assert np.allclose(
            output.data.type(torch.float).cpu().numpy(), np_output, atol=1e-3)
        assert np.allclose(
            x.grad.data.type(torch.float).cpu().numpy(), np_grad, atol=1e-3)

    # Test deprecated parameters
    # (reuses x and rois left over from the last fixture of the loop above)
    roi_align_rotated_module_deprecated = RoIAlignRotated(
        out_size=(pool_h, pool_w),
        spatial_scale=spatial_scale,
        sample_num=sampling_ratio)

    output_1 = roi_align_rotated_module_deprecated(x, rois)

    roi_align_rotated_module_new = RoIAlignRotated(
        output_size=(pool_h, pool_w),
        spatial_scale=spatial_scale,
        sampling_ratio=sampling_ratio)

    output_2 = roi_align_rotated_module_new(x, rois)

    assert np.allclose(
        output_1.data.type(torch.float).cpu().numpy(),
        output_2.data.type(torch.float).cpu().numpy())

125

126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
@pytest.mark.parametrize('device', [
    'cpu',
    pytest.param(
        'cuda',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
@pytest.mark.parametrize('dtype', [
    torch.float,
    pytest.param(
        torch.double,
        marks=pytest.mark.skipif(
            IS_MLU_AVAILABLE,
            reason='MLU does not support for 64-bit floating point')),
    torch.half
])
def test_roialign_rotated(device, dtype):
    """Run the RoIAlignRotated checks for one (device, dtype) combination.

    The value comparison runs for every combination; the numerical gradient
    check is only run for ``torch.double``.

    Args:
        device: Device string supplied by the ``device`` parametrization.
        dtype: Torch dtype supplied by the ``dtype`` parametrization.
    """
    # check double only
    if dtype is torch.double:
        _test_roialign_rotated_gradcheck(device=device, dtype=dtype)
    _test_roialign_rotated_allclose(device=device, dtype=dtype)