# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmcv.ops.multi_scale_deform_attn import (
    MultiScaleDeformableAttention, MultiScaleDeformableAttnFunction,
    multi_scale_deformable_attn_pytorch)
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE

# Feature flags consumed by the tests below; both start optimistic and are
# switched off by the guarded imports that follow.
_USING_PARROTS = True
_IS_AUTOCAST_AVAILABLE = True

# Prefer the parrots gradcheck when it can be imported, otherwise fall back
# to the torch implementation and record which one is in use.
try:
    from parrots.autograd import gradcheck
except ImportError:
    from torch.autograd import gradcheck
    _USING_PARROTS = False

try:
    # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
    # would be imported and used; we should test if our modules support it.
    from torch.cuda.amp import autocast
except ImportError:
    # ``autocast`` stays undefined in this case; the autocast test is skipped
    # through ``_IS_AUTOCAST_AVAILABLE``.
    _IS_AUTOCAST_AVAILABLE = False

@pytest.mark.parametrize('device', [
    'cpu',
    pytest.param(
        'cuda:0',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_multiscale_deformable_attention(device):
    """Smoke-test ``MultiScaleDeformableAttention`` on the given device.

    Checks that an invalid ``embed_dims`` / ``num_heads`` combination is
    rejected, then runs one forward pass on small random inputs and verifies
    that the output has the same shape as the query.

    Args:
        device (str): Device string supplied by the parametrization.
    """
    with pytest.raises(ValueError):
        # embed_dims must be divisible by num_heads
        MultiScaleDeformableAttention(embed_dims=256, num_heads=7)

    device = torch.device(device)
    num_query, bs, embed_dims = 5, 1, 3

    msda = MultiScaleDeformableAttention(
        embed_dims=embed_dims, num_levels=2, num_heads=3)
    msda.init_weights()

    query = torch.rand(num_query, bs, embed_dims).to(device)
    key = torch.rand(num_query, bs, embed_dims).to(device)
    # Two levels of 2x2 and 1x1 give 4 + 1 = 5 flattened positions, which
    # matches the first dimension of ``key``; the second level starts at 4.
    spatial_shapes = torch.tensor(
        [[2, 2], [1, 1]], dtype=torch.long).to(device)
    level_start_index = torch.tensor([0, 4], dtype=torch.long).to(device)
    reference_points = torch.rand(bs, num_query, 2, 2).to(device)

    msda.to(device)
    output = msda(
        query,
        key,
        key,
        reference_points=reference_points,
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index)

    # The forward pass must actually produce a query-shaped result rather
    # than merely not raising.
    assert output.shape == query.shape


def test_forward_multi_scale_deformable_attn_pytorch():
    """Run the pure-PyTorch deformable attention on small random inputs.

    The inputs are generated from a fixed seed and cast to double. The test
    verifies that the call succeeds and that the result is finite with shape
    ``(N, Lq, M * D)``.
    """
    N, M, D = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
    # Total number of flattened positions over all levels (sum of H * W).
    S = int(shapes.prod(1).sum())

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
    # Normalize so the weights sum to one over the last two dimensions.
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)

    output = multi_scale_deformable_attn_pytorch(
        value.double(), shapes, sampling_locations.double(),
        attention_weights.double()).detach()

    assert output.shape == (N, Lq, M * D)
    assert torch.isfinite(output).all()


@pytest.mark.skipif(not IS_CUDA_AVAILABLE, reason='requires CUDA support')
def test_forward_equal_with_pytorch_double():
    """Compare the CUDA op against the PyTorch reference in double precision.

    Both implementations receive the same seeded random inputs cast to
    double; the outputs must agree under ``torch.allclose`` and under tight
    absolute / relative error bounds.
    """
    N, M, D = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
    # Start offset of every level inside the flattened value tensor:
    # a leading zero followed by the running sum of H * W.
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    # Total number of flattened positions over all levels (sum of H * W).
    S = int(shapes.prod(1).sum())

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
    # Normalize so the weights sum to one over the last two dimensions.
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2

    output_pytorch = multi_scale_deformable_attn_pytorch(
        value.double(), shapes, sampling_locations.double(),
        attention_weights.double()).detach().cpu()

    output_cuda = MultiScaleDeformableAttnFunction.apply(
        value.cuda().double(), shapes.cuda(), level_start_index.cuda(),
        sampling_locations.cuda().double(),
        attention_weights.cuda().double(), im2col_step).detach().cpu()

    assert torch.allclose(output_cuda, output_pytorch)
    abs_err = (output_cuda - output_pytorch).abs()
    max_abs_err = abs_err.max()
    max_rel_err = (abs_err / output_pytorch.abs()).max()
    assert max_abs_err < 1e-18
    assert max_rel_err < 1e-15


@pytest.mark.parametrize('device', [
    pytest.param(
        'cuda',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_forward_equal_with_pytorch_float(device):
    """Compare the device op against the PyTorch reference in float32.

    The reference is computed on the seeded CPU tensors; the same tensors
    are then moved to ``device`` and passed to
    ``MultiScaleDeformableAttnFunction``. Outputs must agree under
    ``torch.allclose`` and under absolute / relative error bounds.

    Args:
        device (str): Device string supplied by the parametrization.
    """
    N, M, D = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
    # Start offset of every level inside the flattened value tensor:
    # a leading zero followed by the running sum of H * W.
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    # Total number of flattened positions over all levels (sum of H * W).
    S = int(shapes.prod(1).sum())

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
    # Normalize so the weights sum to one over the last two dimensions.
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2

    output_pytorch = multi_scale_deformable_attn_pytorch(
        value, shapes, sampling_locations, attention_weights).detach().cpu()

    output_device = MultiScaleDeformableAttnFunction.apply(
        value.to(device), shapes.to(device), level_start_index.to(device),
        sampling_locations.to(device), attention_weights.to(device),
        im2col_step).detach().cpu()

    assert torch.allclose(output_device, output_pytorch, rtol=1e-2, atol=1e-3)
    abs_err = (output_device - output_pytorch).abs()
    max_abs_err = abs_err.max()
    max_rel_err = (abs_err / output_pytorch.abs()).max()
    assert max_abs_err < 1e-9
    assert max_rel_err < 1e-6


@pytest.mark.skipif(
    not _IS_AUTOCAST_AVAILABLE, reason='requires autocast support')
@pytest.mark.skipif(not IS_CUDA_AVAILABLE, reason='requires CUDA support')
def test_forward_equal_with_autocast():
    """Check the CUDA op under ``autocast`` for float and half ``value``.

    A float32 PyTorch reference is computed once. The CUDA function is then
    run inside ``autocast`` with ``value`` cast to each tested dtype, and the
    result is compared against the reference with per-dtype error bounds.
    """
    N, M, D = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
    level_offsets = shapes.prod(1).cumsum(0)[:-1]
    level_start_index = torch.cat((shapes.new_zeros((1, )), level_offsets))
    S = int(shapes.prod(1).sum())

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
    weight_total = attention_weights.sum(-1, keepdim=True).sum(
        -2, keepdim=True)
    attention_weights /= weight_total
    im2col_step = 2

    reference = multi_scale_deformable_attn_pytorch(
        value, shapes, sampling_locations, attention_weights).detach().cpu()

    # (dtype of ``value``, bound on max abs error, bound on max rel error)
    cases = (
        (torch.float, 1e-9, 1e-6),  # float test
        (torch.half, 1e-5, 1e-2),  # half test
    )
    for dtype, abs_bound, rel_bound in cases:
        with autocast(enabled=True):
            output_device = MultiScaleDeformableAttnFunction.apply(
                value.cuda().type(dtype), shapes.cuda(),
                level_start_index.cuda(), sampling_locations.cuda(),
                attention_weights.cuda(), im2col_step).detach().cpu()

        # ``allclose`` needs matching dtypes, so the reference is cast to the
        # tested dtype here; the error bounds use the float32 reference.
        assert torch.allclose(
            output_device, reference.to(dtype), rtol=1e-2, atol=1e-3)
        abs_err = (output_device - reference).abs()
        assert abs_err.max() < abs_bound
        assert (abs_err / reference.abs()).max() < rel_bound


@pytest.mark.parametrize('device', [
    pytest.param(
        'cuda',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
@pytest.mark.parametrize('dtype', [
    torch.float,
    pytest.param(
        torch.double,
        marks=pytest.mark.skipif(
            IS_MLU_AVAILABLE,
            reason='MLU does not support for 64-bit floating point')),
    torch.half
])
@pytest.mark.parametrize('channels', [
    4,
    30,
    32,
    64,
    71,
    1025,
])
def test_gradient_numerical(channels,
                            device,
                            dtype,
                            grad_value=True,
                            grad_sampling_loc=True,
                            grad_attn_weight=True):
    """Numerically check the gradients of ``MultiScaleDeformableAttnFunction``.

    Args:
        channels (int): Size of the last dimension of ``value``.
        device (str): ``'cuda'`` or ``'mlu'``.
        dtype (torch.dtype): Parametrized dtype. It is replaced below by a
            device-specific dtype before ``gradcheck`` runs.
        grad_value (bool): Whether ``value`` requires grad.
        grad_sampling_loc (bool): Whether ``sampling_locations`` requires
            grad.
        grad_attn_weight (bool): Whether ``attention_weights`` requires grad.

    Raises:
        ValueError: If ``device`` is neither ``'cuda'`` nor ``'mlu'``.
    """
    N, M = 1, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(3, 2), (2, 1)], dtype=torch.long).to(device)
    # Start offset of every level inside the flattened value tensor:
    # a leading zero followed by the running sum of H * W.
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    # Total number of flattened positions over all levels (sum of H * W).
    S = int(shapes.prod(1).sum())

    value = torch.rand(N, S, M, channels).to(device) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).to(device)
    attention_weights = torch.rand(N, Lq, M, L, P).to(device) + 1e-5
    # Normalize so the weights sum to one over the last two dimensions.
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2

    func = MultiScaleDeformableAttnFunction.apply

    value.requires_grad = grad_value
    sampling_locations.requires_grad = grad_sampling_loc
    attention_weights.requires_grad = grad_attn_weight

    # NOTE(review): the parametrized ``dtype`` is overridden here, so every
    # dtype case runs with the same precision per device -- confirm whether
    # this is intended.
    if device == 'cuda':
        dtype = torch.double
        eps = 1e-6
    elif device == 'mlu':
        dtype = torch.float
        eps = 1e-4
    else:
        # Without this branch ``eps`` would be unbound further down.
        raise ValueError(
            'unsupported device for gradcheck: {!r}'.format(device))

    inputs = (value.to(dtype), shapes, level_start_index,
              sampling_locations.to(dtype), attention_weights.to(dtype),
              im2col_step)

    # The two gradcheck implementations are called with different keyword
    # arguments, so only the shared inputs are factored out.
    if _USING_PARROTS:
        assert gradcheck(
            func, inputs, no_grads=[shapes, level_start_index], eps=eps)
    else:
        assert gradcheck(func, inputs, eps=eps, atol=1e-2)