# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmcv.ops.multi_scale_deform_attn import (
    MultiScaleDeformableAttention, MultiScaleDeformableAttnFunction,
    multi_scale_deformable_attn_pytorch)
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE

# Prefer parrots' gradcheck when parrots is installed; otherwise fall back to
# PyTorch's and record that parrots is not in use (its gradcheck takes a
# different set of keyword arguments, see test_gradient_numerical).
_USING_PARROTS = True
try:
    from parrots.autograd import gradcheck
except ImportError:
    from torch.autograd import gradcheck
    _USING_PARROTS = False

# If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
# would be imported and used; we should test if our modules support it.
_IS_AUTOCAST_AVAILABLE = True
try:
    from torch.cuda.amp import autocast
except ImportError:
    _IS_AUTOCAST_AVAILABLE = False

@pytest.mark.parametrize('device', [
    'cpu',
    pytest.param(
        'cuda:0',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_multiscale_deformable_attention(device):
    """Smoke-test ``MultiScaleDeformableAttention`` construction and forward.

    Checks that an ``embed_dims`` not divisible by ``num_heads`` is rejected,
    then runs a forward pass with the default value projection and with
    ``value_proj_ratio`` set, asserting the output keeps the query's shape.

    Args:
        device (str): Device to run the module on.
    """
    with pytest.raises(ValueError):
        # embed_dims must be divisible by num_heads.
        MultiScaleDeformableAttention(
            embed_dims=256,
            num_heads=7,
        )
    device = torch.device(device)
    msda = MultiScaleDeformableAttention(
        embed_dims=3, num_levels=2, num_heads=3)
    msda.init_weights()
    num_query = 5
    bs = 1
    embed_dims = 3
    query = torch.rand(num_query, bs, embed_dims).to(device)
    key = torch.rand(num_query, bs, embed_dims).to(device)
    # Two levels of 2x2 and 1x1 give 4 + 1 = 5 flattened positions, matching
    # the num_query rows of ``key``, which is also passed as ``value``.
    spatial_shapes = torch.Tensor([[2, 2], [1, 1]]).long().to(device)
    level_start_index = torch.Tensor([0, 4]).long().to(device)
    reference_points = torch.rand(bs, num_query, 2, 2).to(device)
    msda.to(device)
    output = msda(
        query,
        key,
        key,
        reference_points=reference_points,
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index)
    assert output.shape == query.shape

    # test with value_proj_ratio
    embed_dims = 6
    value_proj_ratio = 0.5
    query = torch.rand(num_query, bs, embed_dims).to(device)
    key = torch.rand(num_query, bs, embed_dims).to(device)
    msda = MultiScaleDeformableAttention(
        embed_dims=embed_dims,
        num_levels=2,
        num_heads=3,
        value_proj_ratio=value_proj_ratio)
    msda.init_weights()
    msda.to(device)
    output = msda(
        query,
        key,
        key,
        reference_points=reference_points,
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index)
    # A reduced value projection must still produce embed_dims-wide output.
    assert output.shape == query.shape


def test_forward_multi_scale_deformable_attn_pytorch():
    """Run the pure-PyTorch reference implementation on a tiny input.

    The reference is expected to return a ``(N, Lq, M * D)`` tensor in the
    input dtype with only finite values.
    """
    N, M, D = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
    S = sum((H * W).item() for H, W in shapes)

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
    # Normalise so each (batch, query, head) slice sums to 1 over the
    # level and point dimensions.
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)

    output = multi_scale_deformable_attn_pytorch(
        value.double(), shapes, sampling_locations.double(),
        attention_weights.double()).detach()
    assert output.shape == (N, Lq, M * D)
    assert output.dtype == torch.double
    assert torch.isfinite(output).all()


@pytest.mark.skipif(not IS_CUDA_AVAILABLE, reason='requires CUDA support')
def test_forward_equal_with_pytorch_double():
    """Check the CUDA forward against the PyTorch reference in float64."""
    N, M, D = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
    # Offset of each level's first position in the flattened value tensor.
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum((H * W).item() for H, W in shapes)

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
    # Normalise each (batch, query, head) slice over levels and points.
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2
    output_pytorch = multi_scale_deformable_attn_pytorch(
        value.double(), shapes, sampling_locations.double(),
        attention_weights.double()).detach().cpu()

    output_cuda = MultiScaleDeformableAttnFunction.apply(
        value.cuda().double(), shapes.cuda(), level_start_index.cuda(),
        sampling_locations.cuda().double(),
        attention_weights.cuda().double(), im2col_step).detach().cpu()
    assert torch.allclose(output_cuda, output_pytorch)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    # float64 bounds: the two implementations must agree to near-double
    # precision.
    assert max_abs_err < 1e-18
    assert max_rel_err < 1e-15


@pytest.mark.skipif(not IS_NPU_AVAILABLE, reason='requires NPU support')
def test_forward_equal_with_pytorch_npu():
    """Check the NPU forward against the PyTorch reference in float32."""
    N, M, D = 6, 4, 8
    Lq, L, P = 10000, 4, 8
    shapes = torch.as_tensor([(60, 40), (30, 20), (16, 24), (53, 32)],
                             dtype=torch.int32)
    # Offset of each level's first position in the flattened value tensor.
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum((H * W).item() for H, W in shapes)

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
    # Normalise each (batch, query, head) slice over levels and points.
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2
    output_pytorch = multi_scale_deformable_attn_pytorch(
        value.float(), shapes, sampling_locations.float(),
        attention_weights.float()).detach().cpu()

    output_npu = MultiScaleDeformableAttnFunction.apply(
        value.npu().float(), shapes.npu(), level_start_index.npu(),
        sampling_locations.npu().float(),
        attention_weights.npu().float(), im2col_step).detach().cpu()
    assert torch.allclose(output_npu, output_pytorch)
    max_abs_err = (output_npu - output_pytorch).abs().max()
    max_rel_err = ((output_npu - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    # Both sides are float32, so use the same float32 bounds as
    # test_forward_equal_with_pytorch_float; float64-level bounds such as
    # 1e-18 are below float32 resolution for outputs scaled by 0.01.
    assert max_abs_err < 1e-9
    assert max_rel_err < 1e-6


@pytest.mark.parametrize('device', [
    pytest.param(
        'cuda',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_forward_equal_with_pytorch_float(device):
    """Check the device forward against the PyTorch reference in float32.

    Args:
        device (str): Accelerator device to run the op on.
    """
    N, M, D = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
    # Offset of each level's first position in the flattened value tensor.
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum((H * W).item() for H, W in shapes)

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
    # Normalise each (batch, query, head) slice over levels and points.
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2
    output_pytorch = multi_scale_deformable_attn_pytorch(
        value, shapes, sampling_locations, attention_weights).detach().cpu()

    output_device = MultiScaleDeformableAttnFunction.apply(
        value.to(device), shapes.to(device), level_start_index.to(device),
        sampling_locations.to(device), attention_weights.to(device),
        im2col_step).detach().cpu()
    assert torch.allclose(output_device, output_pytorch, rtol=1e-2, atol=1e-3)
    max_abs_err = (output_device - output_pytorch).abs().max()
    max_rel_err = ((output_device - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    assert max_abs_err < 1e-9
    assert max_rel_err < 1e-6


@pytest.mark.skipif(
    not _IS_AUTOCAST_AVAILABLE, reason='requires autocast support')
@pytest.mark.skipif(not IS_CUDA_AVAILABLE, reason='requires CUDA support')
def test_forward_equal_with_autocast():
    """Compare the CUDA op run under ``autocast`` with the PyTorch reference.

    The op is run twice, with ``value`` cast to float and then to half. Each
    result is checked against the PyTorch reference using bounds specific to
    that dtype.
    """
    N, M, D = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum((H * W).item() for H, W in shapes)

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2
    reference = multi_scale_deformable_attn_pytorch(
        value, shapes, sampling_locations, attention_weights).detach().cpu()

    # Each case: (dtype for ``value``, max abs error, max relative error).
    cases = (
        (torch.float, 1e-9, 1e-6),
        (torch.half, 1e-5, 1e-2),
    )
    for dtype, abs_bound, rel_bound in cases:
        with autocast(enabled=True):
            result = MultiScaleDeformableAttnFunction.apply(
                value.cuda().type(dtype), shapes.cuda(),
                level_start_index.cuda(), sampling_locations.cuda(),
                attention_weights.cuda(), im2col_step).detach().cpu()
        # allclose is evaluated against the reference cast to ``dtype``.
        assert torch.allclose(
            result, reference.to(dtype), rtol=1e-2, atol=1e-3)
        abs_diff = (result - reference).abs()
        assert abs_diff.max() < abs_bound
        assert (abs_diff / reference.abs()).max() < rel_bound


@pytest.mark.parametrize('device', [
    pytest.param(
        'cuda',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
@pytest.mark.parametrize('dtype', [
    torch.float,
    pytest.param(
        torch.double,
        marks=pytest.mark.skipif(
            IS_MLU_AVAILABLE,
            reason='MLU does not support for 64-bit floating point')),
    torch.half
])
@pytest.mark.parametrize('channels', [
    4,
    30,
    32,
    64,
    71,
    1025,
])
def test_gradient_numerical(channels,
                            device,
                            dtype,
                            grad_value=True,
                            grad_sampling_loc=True,
                            grad_attn_weight=True):
    """Compare the op's analytical gradients with numerical ones.

    Args:
        channels (int): Size of the last dimension of ``value``.
        device (str): ``'cuda'`` or ``'mlu'``.
        dtype (torch.dtype): Parametrized dtype. NOTE(review): it is
            overridden below by a device-specific dtype, so every ``dtype``
            case currently runs the same check.
        grad_value (bool): Whether ``value`` requires grad.
        grad_sampling_loc (bool): Whether ``sampling_locations`` requires
            grad.
        grad_attn_weight (bool): Whether ``attention_weights`` requires grad.

    Raises:
        ValueError: If ``device`` is neither ``'cuda'`` nor ``'mlu'``.
    """
    N, M, _ = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(3, 2), (2, 1)], dtype=torch.long).to(device)
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum((H * W).item() for H, W in shapes)

    # Fixed seed, as in the forward tests, so a failing case reproduces.
    torch.manual_seed(3)
    value = torch.rand(N, S, M, channels).to(device) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).to(device)
    attention_weights = torch.rand(N, Lq, M, L, P).to(device) + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2

    func = MultiScaleDeformableAttnFunction.apply

    value.requires_grad = grad_value
    sampling_locations.requires_grad = grad_sampling_loc
    attention_weights.requires_grad = grad_attn_weight

    if device == 'cuda':
        # gradcheck's default tolerances are designed for float64 inputs.
        dtype = torch.double
        eps = 1e-6
    elif device == 'mlu':
        # MLU has no 64-bit float support, so use float32 with a larger
        # finite-difference step.
        dtype = torch.float
        eps = 1e-4
    else:
        # Without this, ``eps`` would be unbound below.
        raise ValueError('unsupported device: {}'.format(device))

    if _USING_PARROTS:
        assert gradcheck(
            func, (value.to(dtype), shapes, level_start_index,
                   sampling_locations.to(dtype), attention_weights.to(dtype),
                   im2col_step),
            no_grads=[shapes, level_start_index],
            eps=eps)
    else:
        assert gradcheck(
            func, (value.to(dtype), shapes, level_start_index,
                   sampling_locations.to(dtype), attention_weights.to(dtype),
                   im2col_step),
            eps=eps,
            atol=1e-2)


@pytest.mark.skipif(not IS_NPU_AVAILABLE, reason='requires NPU support')
def test_backward_equal_with_pytorch_npu():
    """Check NPU gradients against the PyTorch reference in float32.

    Both implementations are back-propagated with an all-ones output
    gradient. The gradients of ``value``, ``sampling_locations`` and
    ``attention_weights`` are then compared.
    """
    N, M, D = 6, 4, 8
    Lq, L, P = 10000, 4, 8
    shapes = torch.as_tensor([(60, 40), (30, 20), (16, 24), (53, 32)],
                             dtype=torch.int32)
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum((H * W).item() for H, W in shapes)

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2
    value.requires_grad = True
    sampling_locations.requires_grad = True
    attention_weights.requires_grad = True
    output_pytorch = multi_scale_deformable_attn_pytorch(
        value.float(), shapes, sampling_locations.float(),
        attention_weights.float())
    grad_output_pytorch = torch.ones_like(output_pytorch)
    output_pytorch.backward(grad_output_pytorch)
    grad_value = value.grad.detach().cpu()
    grad_location = sampling_locations.grad.detach().cpu()
    grad_attn_weight = attention_weights.grad.detach().cpu()

    # ``leaf.npu()`` on a tensor that requires grad returns a non-leaf copy:
    # its ``.grad`` is never populated, and backward through it would flow
    # into the CPU leaves' ``.grad``, which the reference gradients above
    # share storage with. Build independent NPU leaves instead.
    value_npu = value.detach().npu().requires_grad_()
    shapes_npu = shapes.npu()
    level_start_index_npu = level_start_index.npu()
    sampling_locations_npu = sampling_locations.detach().npu().requires_grad_()
    attention_weights_npu = attention_weights.detach().npu().requires_grad_()
    output_npu = MultiScaleDeformableAttnFunction.apply(
        value_npu.float(), shapes_npu, level_start_index_npu,
        sampling_locations_npu.float(), attention_weights_npu.float(),
        im2col_step)
    grad_output_npu = torch.ones_like(output_npu)
    output_npu.backward(grad_output_npu)
    grad_value_npu = value_npu.grad.detach().cpu()
    grad_location_npu = sampling_locations_npu.grad.detach().cpu()
    grad_attn_weight_npu = attention_weights_npu.grad.detach().cpu()
    assert torch.allclose(grad_value_npu, grad_value)
    max_abs_err_1 = (grad_value_npu - grad_value).abs().max()
    max_rel_err_1 = ((grad_value_npu - grad_value).abs() /
                     grad_value.abs()).max()
    assert max_abs_err_1 < 1e-5
    assert max_rel_err_1 < 1e-4
    assert torch.allclose(grad_location_npu, grad_location)
    max_abs_err_2 = (grad_location_npu - grad_location).abs().max()
    max_rel_err_2 = ((grad_location_npu - grad_location).abs() /
                     grad_location.abs()).max()
    assert max_abs_err_2 < 1e-5
    assert max_rel_err_2 < 1e-4
    assert torch.allclose(grad_attn_weight_npu, grad_attn_weight)
    max_abs_err_3 = (grad_attn_weight_npu - grad_attn_weight).abs().max()
    max_rel_err_3 = ((grad_attn_weight_npu - grad_attn_weight).abs() /
                     grad_attn_weight.abs()).max()
    assert max_abs_err_3 < 1e-5
    assert max_rel_err_3 < 1e-4