import os

import numpy as np
import pytest
import torch

_USING_PARROTS = True
try:
    from parrots.autograd import gradcheck
except ImportError:
    from torch.autograd import gradcheck

    _USING_PARROTS = False

cur_dir = os.path.dirname(os.path.abspath(__file__))

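# Hard-coded fixtures for the TIN-Shift tests below.  `inputs` has shape
# (2, 4, 3, 2) and each entry of `shifts` has shape (2, 4), which appears to
# correspond to an input of (N, num_segments, C, H*W) with one shift value
# per group in the op's interface.  `outputs` and `grads` hold the expected
# forward results and input gradients (for an all-ones upstream gradient)
# for each of the two shift tensors.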
inputs = ([[[[0.4369, -3.7571], [-1.1835, -1.6374], [0.9534, -0.1321]],
            [[-0.4658, 0.2162], [-0.8135, -0.3903], [-0.1720, -0.0599]],
            [[0.4851, 1.8224], [0.8973, 0.3779], [2.3454, 1.0319]],
            [[0.0420, 0.3574], [0.7641, 0.2384], [0.2759, 0.4931]]],
           [[[-0.5897, 0.7544], [1.0593, 0.8388], [-0.5732, 0.5692]],
            [[-0.6766, -1.4657], [1.2362, 0.4913], [-1.1820, -1.4341]],
            [[0.6476, -0.7391], [1.4314, -0.3522], [0.8401, -0.7757]],
            [[1.4306, 0.9726], [1.0518, -0.8820], [-0.5129, -0.7876]]]])

shifts = [([[1, 0, 1, -2], [-2, 1, -1, 1]]), ([[2, 1, 2, -1], [-1, 2, 0, 2]])]

outputs = [([[[[0.4369, -3.7571], [-1.1835, -1.6374], [0.9534, -0.1321]],
              [[-0.4658, 0.2162], [-0.8135, -0.3903], [-0.1720, -0.0599]],
              [[0.4851, 1.8224], [0.8973, 0.3779], [2.3454, 1.0319]],
              [[0.0420, 0.3574], [0.7641, 0.2384], [0.2759, 0.4931]]],
             [[[0.6476, -0.7391], [1.4314, -0.3522], [0.8401, -0.7757]],
              [[1.4306, 0.9726], [1.0518, -0.8820], [-0.5129, -0.7876]],
              [[0.0000, 0.0000], [0.0000, 0.0000], [0.0000, 0.0000]],
              [[0.0000, 0.0000], [0.0000, 0.0000], [0.0000, 0.0000]]]]),
           ([[[[0.4369, -3.7571], [-1.1835, -1.6374], [0.9534, -0.1321]],
              [[-0.4658, 0.2162], [-0.8135, -0.3903], [-0.1720, -0.0599]],
              [[0.4851, 1.8224], [0.8973, 0.3779], [2.3454, 1.0319]],
              [[0.0420, 0.3574], [0.7641, 0.2384], [0.2759, 0.4931]]],
             [[[-0.6766, -1.4657], [1.2362, 0.4913], [-1.1820, -1.4341]],
              [[0.6476, -0.7391], [1.4314, -0.3522], [0.8401, -0.7757]],
              [[1.4306, 0.9726], [1.0518, -0.8820], [-0.5129, -0.7876]],
              [[0.0000, 0.0000], [0.0000, 0.0000], [0.0000, 0.0000]]]])]

grads = [[[[[1., 1.], [1., 1.], [1., 1.]], [[1., 1.], [1., 1.], [1., 1.]],
           [[1., 1.], [1., 1.], [1., 1.]], [[1., 1.], [1., 1.], [1., 1.]]],
          [[[1., 1.], [1., 1.], [1., 1.]], [[1., 1.], [1., 1.], [1., 1.]],
           [[0., 0.], [0., 0.], [0., 0.]], [[0., 0.], [0., 0.], [0., 0.]]]],
         [[[[1., 1.], [1., 1.], [1., 1.]], [[1., 1.], [1., 1.], [1., 1.]],
           [[1., 1.], [1., 1.], [1., 1.]], [[1., 1.], [1., 1.], [1., 1.]]],
          [[[1., 1.], [1., 1.], [1., 1.]], [[1., 1.], [1., 1.], [1., 1.]],
           [[1., 1.], [1., 1.], [1., 1.]], [[0., 0.], [0., 0.], [0., 0.]]]]]


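# Numerical gradient check for tin_shift on CUDA.  gradcheck comes from
# parrots when available, otherwise from torch.autograd; half precision is
# skipped because the CPU "add_cpu/sub_cpu" kernels gradcheck relies on are
# not implemented for Half.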
def _test_tinshift_gradcheck(dtype):
    try:
        from mmcv.ops import tin_shift
    except ModuleNotFoundError:
        pytest.skip('TinShift op is not successfully compiled')

    if dtype == torch.half:
        pytest.skip('"add_cpu/sub_cpu" not implemented for Half')

    for shift in shifts:
        np_input = np.array(inputs)
        np_shift = np.array(shift)

        x = torch.tensor(
            np_input, dtype=dtype, device='cuda', requires_grad=True)
        shift = torch.tensor(np_shift, device='cuda').int()
        if torch.__version__ == 'parrots':
            gradcheck(tin_shift, (x, shift))
        else:
            gradcheck(tin_shift, (x, shift), atol=1, rtol=0.1)


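# Check tin_shift's forward output and the gradient it propagates back to the
# input (with an all-ones upstream gradient) against the precomputed
# `outputs` and `grads` references for each shift tensor.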
def _test_tinshift_allclose(dtype):
    try:
        from mmcv.ops import tin_shift
    except ModuleNotFoundError:
        pytest.skip('TinShift op is not successfully compiled')

    for shift, output, grad in zip(shifts, outputs, grads):
        np_input = np.array(inputs)
        np_shift = np.array(shift)
        np_output = np.array(output)
        np_grad = np.array(grad)

        x = torch.tensor(
            np_input, dtype=dtype, device='cuda', requires_grad=True)
        shift = torch.tensor(np_shift, device='cuda').int()

        output = tin_shift(x, shift)
        output.backward(torch.ones_like(output))
        assert np.allclose(
            output.data.type(torch.float).cpu().numpy(), np_output, 1e-3)
        assert np.allclose(
            x.grad.data.type(torch.float).cpu().numpy(), np_grad, 1e-3)


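# CUDA-only entry point, parametrized over float32, float64 and float16.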
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
@pytest.mark.parametrize('dtype', [torch.float, torch.double, torch.half])
def test_tinshift(dtype):
    _test_tinshift_allclose(dtype=dtype)
    _test_tinshift_gradcheck(dtype=dtype)