# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmcv.ops.multi_scale_deform_attn import (
    MultiScaleDeformableAttention, MultiScaleDeformableAttnFunction,
    multi_scale_deformable_attn_pytorch)

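# Prefer parrots' gradcheck when available; otherwise use torch's.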
_USING_PARROTS = True
try:
    from parrots.autograd import gradcheck
except ImportError:
    from torch.autograd import gradcheck
    _USING_PARROTS = False


@pytest.mark.parametrize('device_type', [
    'cpu',
    pytest.param(
        'cuda:0',
        marks=pytest.mark.skipif(
            not torch.cuda.is_available(), reason='requires CUDA support'))
])
def test_multiscale_deformable_attention(device_type):
    with pytest.raises(ValueError):
        # embed_dims must be divisible by num_heads; 256 % 7 != 0
        MultiScaleDeformableAttention(
            embed_dims=256,
            num_heads=7,
        )
    device = torch.device(device_type)
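    # Two levels (2x2 and 1x1) flatten to 4 + 1 = 5 value positions.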
    msda = MultiScaleDeformableAttention(
        embed_dims=3, num_levels=2, num_heads=3)
    msda.init_weights()
    num_query = 5
    bs = 1
    embed_dims = 3
    query = torch.rand(num_query, bs, embed_dims).to(device)
    key = torch.rand(num_query, bs, embed_dims).to(device)
    spatial_shapes = torch.Tensor([[2, 2], [1, 1]]).long().to(device)
    level_start_index = torch.Tensor([0, 4]).long().to(device)
    reference_points = torch.rand(bs, num_query, 2, 2).to(device)
    msda.to(device)
    msda(
        query,
        key,
        key,
        reference_points=reference_points,
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index)


def test_forward_multi_scale_deformable_attn_pytorch():
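    # Smoke test: the pure-PyTorch reference should run end to end.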
    N, M, D = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
    S = sum((H * W).item() for H, W in shapes)

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
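    # Shift weights off zero, then normalize over levels and points.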
    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)

    multi_scale_deformable_attn_pytorch(value.double(), shapes,
                                        sampling_locations.double(),
                                        attention_weights.double()).detach()


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_forward_equal_with_pytorch_double():
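    # In float64 the CUDA kernel should match the reference almost exactly.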
    N, M, D = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
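    # Start offset of each level in the flattened value tensor: [0, 24].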
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum((H * W).item() for H, W in shapes)

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
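    # im2col_step caps how many batch samples each kernel launch handles.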
    im2col_step = 2
    output_pytorch = multi_scale_deformable_attn_pytorch(
        value.double(), shapes, sampling_locations.double(),
        attention_weights.double()).detach().cpu()

    output_cuda = MultiScaleDeformableAttnFunction.apply(
        value.double(), shapes, level_start_index, sampling_locations.double(),
        attention_weights.double(), im2col_step).detach().cpu()
    assert torch.allclose(output_cuda, output_pytorch)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    assert max_abs_err < 1e-18
    assert max_rel_err < 1e-15


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_forward_equal_with_pytorch_float():
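    # Same check in float32, with correspondingly looser tolerances.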
    N, M, D = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum((H * W).item() for H, W in shapes)

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2
    output_pytorch = multi_scale_deformable_attn_pytorch(
        value, shapes, sampling_locations, attention_weights).detach().cpu()

    output_cuda = MultiScaleDeformableAttnFunction.apply(
        value, shapes, level_start_index, sampling_locations,
        attention_weights, im2col_step).detach().cpu()
    assert torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    assert max_abs_err < 1e-9
    assert max_rel_err < 1e-6


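# Exercise channel counts that are and are not powers of two.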
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
@pytest.mark.parametrize('channels', [
    4,
    30,
    32,
    64,
    71,
    1025,
])
def test_gradient_numerical(channels,
                            grad_value=True,
                            grad_sampling_loc=True,
                            grad_attn_weight=True):

    N, M, _ = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(3, 2), (2, 1)], dtype=torch.long).cuda()
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum((H * W).item() for H, W in shapes)

    value = torch.rand(N, S, M, channels).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2

    func = MultiScaleDeformableAttnFunction.apply

    value.requires_grad = grad_value
    sampling_locations.requires_grad = grad_sampling_loc
    attention_weights.requires_grad = grad_attn_weight
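    # Check analytical CUDA gradients against numerical ones in float64.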
    if _USING_PARROTS:
        assert gradcheck(
            func, (value.double(), shapes, level_start_index,
                   sampling_locations.double(), attention_weights.double(),
                   im2col_step),
            no_grads=[shapes, level_start_index])
    else:
        assert gradcheck(func, (value.double(), shapes, level_start_index,
                                sampling_locations.double(),
                                attention_weights.double(), im2col_step))