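"""Tests for multi-scale deformable attention in mmcv.ops.

Covers the MultiScaleDeformableAttention module, forward equivalence
between the CUDA MultiScaleDeformableAttnFunction and the pure-PyTorch
reference implementation, and numerical gradient checks.
"""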
import pytest
import torch

from mmcv.ops.multi_scale_deform_attn import (
    MultiScaleDeformableAttention, MultiScaleDeformableAttnFunction,
    multi_scale_deformable_attn_pytorch)

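# Prefer the parrots gradcheck when it is available and fall back to
# torch.autograd.gradcheck otherwise; _USING_PARROTS selects the matching
# gradcheck call signature in test_gradient_numerical below.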
_USING_PARROTS = True
try:
    from parrots.autograd import gradcheck
except ImportError:
    from torch.autograd import gradcheck
    _USING_PARROTS = False


@pytest.mark.parametrize('device_type', [
    'cpu',
    pytest.param(
        'cuda:0',
        marks=pytest.mark.skipif(
            not torch.cuda.is_available(), reason='requires CUDA support'))
])
def test_multiscale_deformable_attention(device_type):
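    """MultiScaleDeformableAttention should reject an embed_dims that is
    not divisible by num_heads and run a forward pass on each device."""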

    with pytest.raises(ValueError):
        # embed_dims must be divisible by num_heads
        MultiScaleDeformableAttention(
            embed_dims=256,
            num_heads=7,
        )
    device = torch.device(device_type)
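    # A deliberately tiny module: 3-dim embeddings split over 3 heads and
    # 2 feature levels; query/key follow the (num_query, bs, embed_dims)
    # layout the module expects by default.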
    msda = MultiScaleDeformableAttention(
        embed_dims=3, num_levels=2, num_heads=3)
    msda.init_weights()
    num_query = 5
    bs = 1
    embed_dims = 3
    query = torch.rand(num_query, bs, embed_dims).to(device)
    key = torch.rand(num_query, bs, embed_dims).to(device)
    spatial_shapes = torch.Tensor([[2, 2], [1, 1]]).long().to(device)
    level_start_index = torch.Tensor([0, 4]).long().to(device)
    reference_points = torch.rand(bs, num_query, 2, 2).to(device)
    msda.to(device)
    msda(
        query,
        key,
        key,
        reference_points=reference_points,
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index)


def test_forward_multi_scale_deformable_attn_pytorch():
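    # N: batch size, M: attention heads, D: channels per head,
    # Lq: queries, L: feature levels, P: sampling points per level.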
    N, M, D = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
    S = sum((H * W).item() for H, W in shapes)

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
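    # Normalize so the L * P sampling weights of each query/head sum to 1.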
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)

    multi_scale_deformable_attn_pytorch(value.double(), shapes,
                                        sampling_locations.double(),
                                        attention_weights.double()).detach()


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_forward_equal_with_pytorch_double():
    N, M, D = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
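    # Start offset of each level in the flattened value tensor: [0, H0 * W0].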
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum((H * W).item() for H, W in shapes)

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
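    # im2col_step sets how many batch samples the CUDA kernel processes
    # per call.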
    im2col_step = 2
    output_pytorch = multi_scale_deformable_attn_pytorch(
        value.double(), shapes, sampling_locations.double(),
        attention_weights.double()).detach().cpu()

    output_cuda = MultiScaleDeformableAttnFunction.apply(
        value.double(), shapes, level_start_index, sampling_locations.double(),
        attention_weights.double(), im2col_step).detach().cpu()
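    # In double precision the CUDA kernel and the PyTorch reference should
    # agree almost exactly.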
    assert torch.allclose(output_cuda, output_pytorch)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    assert max_abs_err < 1e-18
    assert max_rel_err < 1e-15


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_forward_equal_with_pytorch_float():
    N, M, D = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum((H * W).item() for H, W in shapes)

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2
    output_pytorch = multi_scale_deformable_attn_pytorch(
        value, shapes, sampling_locations, attention_weights).detach().cpu()

    output_cuda = MultiScaleDeformableAttnFunction.apply(
        value, shapes, level_start_index, sampling_locations,
        attention_weights, im2col_step).detach().cpu()
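    # Single precision accumulates more rounding error, so the tolerances
    # are looser than in the double-precision test above.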
    assert torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    assert max_abs_err < 1e-9
    assert max_rel_err < 1e-6


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
@pytest.mark.parametrize('channels', [
    4,
    30,
    32,
    64,
    71,
    1025,
])
def test_gradient_numerical(channels,
                            grad_value=True,
                            grad_sampling_loc=True,
                            grad_attn_weight=True):
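    """Compare analytical CUDA gradients against numerical gradients."""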

    N, M, _ = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(3, 2), (2, 1)], dtype=torch.long).cuda()
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum((H * W).item() for H, W in shapes)

    value = torch.rand(N, S, M, channels).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2

    func = MultiScaleDeformableAttnFunction.apply

    value.requires_grad = grad_value
    sampling_locations.requires_grad = grad_sampling_loc
    attention_weights.requires_grad = grad_attn_weight
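    # gradcheck runs in double precision; the integer tensors (shapes,
    # level_start_index) carry no gradients and are excluded from the check.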
    if _USING_PARROTS:
        assert gradcheck(
            func, (value.double(), shapes, level_start_index,
                   sampling_locations.double(), attention_weights.double(),
                   im2col_step),
            no_grads=[shapes, level_start_index])
    else:
        assert gradcheck(func, (value.double(), shapes, level_start_index,
                                sampling_locations.double(),
                                attention_weights.double(), im2col_step))