Commit a0939977 authored by ZShaopeng's avatar ZShaopeng Committed by Zaida Zhou
Browse files

[Feature] Support MultiScaleDeformableAttn with cambricon MLU backend

parent 193de43b
......@@ -33,7 +33,7 @@ We implement common ops used in detection, segmentation, etc.
| MergeCells | | √ | | |
| MinAreaPolygon | | √ | | |
| ModulatedDeformConv2d | √ | √ | | |
| MultiScaleDeformableAttn | | √ | | |
| MultiScaleDeformableAttn | | √ | | |
| NMS | √ | √ | √ | |
| NMSRotated | √ | √ | | |
| NMSQuadri | √ | √ | | |
......
......@@ -33,7 +33,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| MergeCells | | √ | | |
| MinAreaPolygon | | √ | | |
| ModulatedDeformConv2d | √ | √ | | |
| MultiScaleDeformableAttn | | √ | | |
| MultiScaleDeformableAttn | | √ | | |
| NMS | √ | √ | √ | |
| NMSRotated | √ | √ | | |
| NMSQuadri | √ | √ | | |
......
......@@ -362,4 +362,37 @@ __mlu_func__ inline void convertFloat2half(half *dst, float *src,
#endif
}
/*!
 * @brief Recursively sum-pools a [high_dim, low_dim] buffer along the high
 *        dimension, in place, until a single [1, low_dim] row of column
 *        sums remains. Each pass reduces high_dim by up to a factor of
 *        kernel_limit, so the loop runs O(log_kernel_limit(high_dim)) times.
 * @param[in,out] dst
 *        Pointer to NRAM holding the input data; on return the first
 *        low_dim elements contain the per-column sums. The rest of the
 *        buffer is clobbered with intermediate partial sums.
 * @param[in] low_dim
 *        Size of the low (innermost, preserved) dimension.
 * @param[in] high_dim
 *        Size of the high (outer, reduced) dimension.
 * @param[in] kernel_limit
 *        Maximum pooling-window height handled by one __bang_sumpool call
 *        (hardware/intrinsic limit on the kernel size per pass).
 ******************************************************************************/
template <typename T>
__mlu_func__ void recursiveSumPool(T *dst, int low_dim, int high_dim,
                                   int kernel_limit) {
  // Loop until the high dimension collapses to one row.
  for (; high_dim > 1;) {
    // Split the current high_dim rows into repeat_s full windows of
    // kernel_limit rows plus a remainder of remain_s rows.
    int repeat_s = high_dim / kernel_limit;
    int remain_s = high_dim % kernel_limit;
    if (remain_s) {
      // Fold the leading remain_s rows into one partial-sum row written
      // at the start of dst.
      // NOTE(review): argument meanings follow the BANG __bang_sumpool
      // signature (dst, src, channel, width, height, kernel/stride params)
      // — confirm against the Cambricon BANG C intrinsic reference.
      __bang_sumpool((T *)dst, (T *)dst, low_dim, 1, remain_s, 1, remain_s, 1,
                     1);
    }
    if (repeat_s) {
      // Fold the remaining kernel_limit * repeat_s rows (starting right
      // after the remainder rows) into repeat_s partial-sum rows, placed
      // immediately after the remainder's output row — offset by one
      // low_dim row only when a remainder row was produced above.
      __bang_sumpool((T *)dst + (remain_s > 0 ? low_dim : 0),
                     (T *)dst + remain_s * low_dim, low_dim,
                     kernel_limit * repeat_s, 1, kernel_limit, 1, 1,
                     kernel_limit);
    }
    // Next iteration works on the repeat_s window sums plus the single
    // remainder sum (if any).
    high_dim = repeat_s + (bool)remain_s;
  }
  return;
}
#endif // COMMON_MLU_HELPER_HPP_
This diff is collapsed.
This diff is collapsed.
......@@ -12,6 +12,7 @@ from mmengine.registry import MODELS
from mmengine.utils import deprecated_api_warning
from torch.autograd.function import Function, once_differentiable
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
......@@ -26,7 +27,7 @@ class MultiScaleDeformableAttnFunction(Function):
sampling_locations: torch.Tensor,
attention_weights: torch.Tensor,
im2col_step: torch.Tensor) -> torch.Tensor:
"""GPU version of multi-scale deformable attention.
"""GPU/MLU version of multi-scale deformable attention.
Args:
value (torch.Tensor): The value has shape
......@@ -63,7 +64,7 @@ class MultiScaleDeformableAttnFunction(Function):
@staticmethod
@once_differentiable
def backward(ctx, grad_output: torch.Tensor) -> tuple:
"""GPU version of backward function.
"""GPU/MLU version of backward function.
Args:
grad_output (torch.Tensor): Gradient of output tensor of forward.
......@@ -346,7 +347,8 @@ class MultiScaleDeformableAttention(BaseModule):
raise ValueError(
f'Last dim of reference_points must be'
f' 2 or 4, but get {reference_points.shape[-1]} instead.')
if torch.cuda.is_available() and value.is_cuda:
if ((IS_CUDA_AVAILABLE and value.is_cuda)
or (IS_MLU_AVAILABLE and value.is_mlu)):
output = MultiScaleDeformableAttnFunction.apply(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, self.im2col_step)
......
......@@ -5,6 +5,7 @@ import torch
from mmcv.ops.multi_scale_deform_attn import (
MultiScaleDeformableAttention, MultiScaleDeformableAttnFunction,
multi_scale_deformable_attn_pytorch)
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
_USING_PARROTS = True
try:
......@@ -14,22 +15,25 @@ except ImportError:
_USING_PARROTS = False
@pytest.mark.parametrize('device_type', [
@pytest.mark.parametrize('device', [
'cpu',
pytest.param(
'cuda:0',
marks=pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support'))
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_multiscale_deformable_attention(device_type):
def test_multiscale_deformable_attention(device):
with pytest.raises(ValueError):
# embed_dims must be divisible by num_heads,
MultiScaleDeformableAttention(
embed_dims=256,
num_heads=7,
)
device = torch.device(device_type)
device = torch.device(device)
msda = MultiScaleDeformableAttention(
embed_dims=3, num_levels=2, num_heads=3)
msda.init_weights()
......@@ -70,20 +74,19 @@ def test_forward_multi_scale_deformable_attn_pytorch():
attention_weights.double()).detach()
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
@pytest.mark.skipif(not IS_CUDA_AVAILABLE, reason='requires CUDA support')
def test_forward_equal_with_pytorch_double():
N, M, D = 1, 2, 2
Lq, L, P = 2, 2, 2
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
level_start_index = torch.cat((shapes.new_zeros(
(1, )), shapes.prod(1).cumsum(0)[:-1]))
S = sum((H * W).item() for H, W in shapes)
torch.manual_seed(3)
value = torch.rand(N, S, M, D).cuda() * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
value = torch.rand(N, S, M, D) * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2)
attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
attention_weights /= attention_weights.sum(
-1, keepdim=True).sum(
-2, keepdim=True)
......@@ -93,8 +96,9 @@ def test_forward_equal_with_pytorch_double():
attention_weights.double()).detach().cpu()
output_cuda = MultiScaleDeformableAttnFunction.apply(
value.double(), shapes, level_start_index, sampling_locations.double(),
attention_weights.double(), im2col_step).detach().cpu()
value.cuda().double(), shapes.cuda(), level_start_index.cuda(),
sampling_locations.cuda().double(),
attention_weights.cuda().double(), im2col_step).detach().cpu()
assert torch.allclose(output_cuda, output_pytorch)
max_abs_err = (output_cuda - output_pytorch).abs().max()
max_rel_err = ((output_cuda - output_pytorch).abs() /
......@@ -103,20 +107,28 @@ def test_forward_equal_with_pytorch_double():
assert max_rel_err < 1e-15
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_forward_equal_with_pytorch_float():
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_forward_equal_with_pytorch_float(device):
N, M, D = 1, 2, 2
Lq, L, P = 2, 2, 2
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
level_start_index = torch.cat((shapes.new_zeros(
(1, )), shapes.prod(1).cumsum(0)[:-1]))
S = sum((H * W).item() for H, W in shapes)
torch.manual_seed(3)
value = torch.rand(N, S, M, D).cuda() * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
value = torch.rand(N, S, M, D) * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2)
attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
attention_weights /= attention_weights.sum(
-1, keepdim=True).sum(
-2, keepdim=True)
......@@ -124,19 +136,37 @@ def test_forward_equal_with_pytorch_float():
output_pytorch = multi_scale_deformable_attn_pytorch(
value, shapes, sampling_locations, attention_weights).detach().cpu()
output_cuda = MultiScaleDeformableAttnFunction.apply(
value, shapes, level_start_index, sampling_locations,
attention_weights, im2col_step).detach().cpu()
assert torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
max_abs_err = (output_cuda - output_pytorch).abs().max()
max_rel_err = ((output_cuda - output_pytorch).abs() /
output_device = MultiScaleDeformableAttnFunction.apply(
value.to(device), shapes.to(device), level_start_index.to(device),
sampling_locations.to(device), attention_weights.to(device),
im2col_step).detach().cpu()
assert torch.allclose(output_device, output_pytorch, rtol=1e-2, atol=1e-3)
max_abs_err = (output_device - output_pytorch).abs().max()
max_rel_err = ((output_device - output_pytorch).abs() /
output_pytorch.abs()).max()
assert max_abs_err < 1e-9
assert max_rel_err < 1e-6
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
@pytest.mark.parametrize('dtype', [
torch.float,
pytest.param(
torch.double,
marks=pytest.mark.skipif(
IS_MLU_AVAILABLE,
reason='MLU does not support for 64-bit floating point')),
torch.half
])
@pytest.mark.parametrize('channels', [
4,
30,
......@@ -146,20 +176,22 @@ def test_forward_equal_with_pytorch_float():
1025,
])
def test_gradient_numerical(channels,
device,
dtype,
grad_value=True,
grad_sampling_loc=True,
grad_attn_weight=True):
N, M, _ = 1, 2, 2
Lq, L, P = 2, 2, 2
shapes = torch.as_tensor([(3, 2), (2, 1)], dtype=torch.long).cuda()
shapes = torch.as_tensor([(3, 2), (2, 1)], dtype=torch.long).to(device)
level_start_index = torch.cat((shapes.new_zeros(
(1, )), shapes.prod(1).cumsum(0)[:-1]))
S = sum((H * W).item() for H, W in shapes)
value = torch.rand(N, S, M, channels).cuda() * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
value = torch.rand(N, S, M, channels).to(device) * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2).to(device)
attention_weights = torch.rand(N, Lq, M, L, P).to(device) + 1e-5
attention_weights /= attention_weights.sum(
-1, keepdim=True).sum(
-2, keepdim=True)
......@@ -170,13 +202,23 @@ def test_gradient_numerical(channels,
value.requires_grad = grad_value
sampling_locations.requires_grad = grad_sampling_loc
attention_weights.requires_grad = grad_attn_weight
if device == 'cuda':
dtype = torch.double
eps = 1e-6
elif device == 'mlu':
dtype = torch.float
eps = 1e-4
if _USING_PARROTS:
assert gradcheck(
func, (value.double(), shapes, level_start_index,
sampling_locations.double(), attention_weights.double(),
func, (value.to(dtype), shapes, level_start_index,
sampling_locations.to(dtype), attention_weights.to(dtype),
im2col_step),
no_grads=[shapes, level_start_index])
no_grads=[shapes, level_start_index],
eps=eps)
else:
assert gradcheck(func, (value.double(), shapes, level_start_index,
sampling_locations.double(),
attention_weights.double(), im2col_step))
assert gradcheck(
func, (value.to(dtype), shapes, level_start_index,
sampling_locations.to(dtype), attention_weights.to(dtype),
im2col_step),
eps=eps,
atol=1e-2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment