Unverified Commit 97e5bada authored by Zaida Zhou, committed by GitHub

continue PR #1223 (#1404)



* fix MultiScaleDeformableAttention inference issue in CPU mode

* fix lint

* add unittest

* remove some code

* Update tests/test_ops/test_ms_deformable_attn.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* fix device

* remove device

* add more device

* refactor unittest
Co-authored-by: zhicheng huang <zhichenghzc@gmail.com>
Co-authored-by: zhangshilong <2392587229zsl@gmail.com>
Co-authored-by: Shilong Zhang <61961338+jshilong@users.noreply.github.com>
parent 1cd01db9
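
The core of the fix is visible in the first hunk below: torch.cuda.is_available() describes the machine, not the tensors, so on a CUDA-capable host a model running on CPU was still dispatched to the compiled CUDA kernel and failed at inference. The new condition additionally requires the input tensor to live on the GPU. A minimal sketch of the dispatch predicate (the helper name _use_cuda_kernel is hypothetical; the real code inlines the condition):

import torch

def _use_cuda_kernel(value: torch.Tensor) -> bool:
    # Hypothetical helper for illustration: CUDA being installed is not
    # enough. The compiled kernel only accepts GPU tensors, so CPU inputs
    # must fall back to the pure-PyTorch implementation.
    return torch.cuda.is_available() and value.is_cuda
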
mmcv/ops/multi_scale_deform_attn.py
@@ -341,14 +341,13 @@ class MultiScaleDeformableAttention(BaseModule):
             raise ValueError(
                 f'Last dim of reference_points must be'
                 f' 2 or 4, but get {reference_points.shape[-1]} instead.')
-        if torch.cuda.is_available():
+        if torch.cuda.is_available() and value.is_cuda:
             output = MultiScaleDeformableAttnFunction.apply(
                 value, spatial_shapes, level_start_index, sampling_locations,
                 attention_weights, self.im2col_step)
         else:
             output = multi_scale_deformable_attn_pytorch(
-                value, spatial_shapes, level_start_index, sampling_locations,
-                attention_weights, self.im2col_step)
+                value, spatial_shapes, sampling_locations, attention_weights)
 
         output = self.output_proj(output)
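
Note that the pure-PyTorch fallback no longer receives level_start_index or self.im2col_step. A minimal sketch of calling the fallback directly on CPU, assuming the usual mmcv import path and borrowing the tensor layout from the forward test below; the shapes are illustrative, not mandated by the API:

import torch
from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch

bs, num_heads, head_dims = 1, 2, 2                 # batch, heads, dims per head
num_queries, num_levels, num_points = 2, 2, 2
spatial_shapes = torch.as_tensor([[6, 4], [3, 2]], dtype=torch.long)
num_keys = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())  # 24 + 6 = 30

value = torch.rand(bs, num_keys, num_heads, head_dims)
sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points)

# Everything stays on CPU; no CUDA extension is involved.
output = multi_scale_deformable_attn_pytorch(
    value, spatial_shapes, sampling_locations, attention_weights)
assert output.shape == (bs, num_queries, num_heads * head_dims)  # (1, 2, 4)
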
tests/test_ops/test_ms_deformable_attn.py
@@ -13,6 +13,43 @@ except ImportError:
     _USING_PARROTS = False
 
 
+@pytest.mark.parametrize('device_type', [
+    'cpu',
+    pytest.param(
+        'cuda:0',
+        marks=pytest.mark.skipif(
+            not torch.cuda.is_available(), reason='requires CUDA support'))
+])
+def test_multiscale_deformable_attention(device_type):
+    with pytest.raises(ValueError):
+        # embed_dims must be divisible by num_heads,
+        MultiScaleDeformableAttention(
+            embed_dims=256,
+            num_heads=7,
+        )
+    device = torch.device(device_type)
+    msda = MultiScaleDeformableAttention(
+        embed_dims=3, num_levels=2, num_heads=3)
+    msda.init_weights()
+    num_query = 5
+    bs = 1
+    embed_dims = 3
+    query = torch.rand(num_query, bs, embed_dims).to(device)
+    key = torch.rand(num_query, bs, embed_dims).to(device)
+    spatial_shapes = torch.Tensor([[2, 2], [1, 1]]).long().to(device)
+    level_start_index = torch.Tensor([0, 4]).long().to(device)
+    reference_points = torch.rand(bs, num_query, 2, 2).to(device)
+    msda.to(device)
+    msda(
+        query,
+        key,
+        key,
+        reference_points=reference_points,
+        spatial_shapes=spatial_shapes,
+        level_start_index=level_start_index)
+
+
 def test_forward_multi_scale_deformable_attn_pytorch():
     N, M, D = 1, 2, 2
     Lq, L, P = 2, 2, 2
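
A detail worth spelling out in the new test: level_start_index is the exclusive prefix sum of H*W over the levels in spatial_shapes, so [[2, 2], [1, 1]] gives [0, 4] (level 0 starts at key 0, level 1 starts after the 2*2 = 4 keys of level 0). A short sketch of deriving the index instead of hard-coding it:

import torch

spatial_shapes = torch.tensor([[2, 2], [1, 1]])
keys_per_level = spatial_shapes.prod(dim=1)        # tensor([4, 1])
level_start_index = torch.cat(
    (keys_per_level.new_zeros(1), keys_per_level.cumsum(0)[:-1]))
print(level_start_index)                           # tensor([0, 4])
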
@@ -142,20 +179,3 @@ def test_gradient_numerical(channels,
     assert gradcheck(func, (value.double(), shapes, level_start_index,
                             sampling_locations.double(),
                             attention_weights.double(), im2col_step))
-
-
-def test_multiscale_deformable_attention():
-    with pytest.raises(ValueError):
-        # embed_dims must be divisible by num_heads,
-        MultiScaleDeformableAttention(
-            embed_dims=256,
-            num_heads=7,
-        )
-    with pytest.raises(ValueError):
-        # embed_dims must be divisible by num_heads,
-        MultiScaleDeformableAttention(
-            embed_dims=256,
-            num_heads=7,
-        )
-    MultiScaleDeformableAttention(embed_dims=256, num_heads=8)