OpenDAS / MMCV · Commit 732ff509 (unverified)
Authored May 25, 2021 by pc; committed by GitHub on May 25, 2021.

Add ms_deformable_attn in parrots (#1042)

Parent: a6377240
Showing 9 changed files with 609 additions and 58 deletions.
mmcv/ops/csrc/parrots/ms_deform_attn.cpp          +79   -0
mmcv/ops/csrc/parrots/ms_deform_attn_cuda.cu      +360  -0
mmcv/ops/csrc/parrots/ms_deform_attn_parrots.cpp  +68   -0
mmcv/ops/csrc/pytorch/ms_deform_attn.cpp          +19   -14
mmcv/ops/csrc/pytorch/ms_deform_attn_cuda.cu      +3    -8
mmcv/ops/csrc/pytorch/pybind.cpp                  +9    -8
mmcv/ops/multi_scale_deform_attn.py               +22   -14
mmcv/utils/ext_loader.py                          +32   -8
tests/test_ops/test_ms_deformable_attn.py         +17   -6
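Taken together, the change set ports the multi-scale deformable attention op of Deformable DETR to parrots (the three new files under mmcv/ops/csrc/parrots/), switches the backward bindings from returning the three gradients to writing them into caller-allocated tensors, registers the value-returning forward with the extension loader, and updates the gradient tests to run under both backends. A minimal usage sketch of the op through its autograd wrapper, assuming a CUDA build of mmcv; the tensor shapes follow the assertions in the CUDA host code reproduced below:

import torch
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttnFunction

bs, num_heads, channels = 2, 2, 4
num_queries, num_levels, num_points = 5, 2, 2

# Two feature levels, 4x4 and 2x2; `value` packs all levels along dim 1.
spatial_shapes = torch.tensor([[4, 4], [2, 2]], dtype=torch.long).cuda()
level_start_index = torch.cat(
    (spatial_shapes.new_zeros(1), spatial_shapes.prod(1).cumsum(0)[:-1]))
num_keys = int(spatial_shapes.prod(1).sum())

value = torch.rand(bs, num_keys, num_heads, channels).cuda()
sampling_locations = torch.rand(
    bs, num_queries, num_heads, num_levels, num_points, 2).cuda()
attention_weights = torch.rand(
    bs, num_queries, num_heads, num_levels, num_points).cuda()
# Normalize the weights over (levels, points), as the attention module does.
attention_weights /= attention_weights.sum(
    -1, keepdim=True).sum(-2, keepdim=True)

output = MultiScaleDeformableAttnFunction.apply(
    value, spatial_shapes, level_start_index, sampling_locations,
    attention_weights, 2)  # final argument: im2col_step
assert output.shape == (bs, num_queries, num_heads * channels)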
mmcv/ops/csrc/parrots/ms_deform_attn.cpp (new file, mode 100644)
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from
* https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include "pytorch_cpp_helper.hpp"

#ifdef MMCV_WITH_CUDA
Tensor ms_deform_attn_cuda_forward(const Tensor &value,
                                   const Tensor &spatial_shapes,
                                   const Tensor &level_start_index,
                                   const Tensor &sampling_loc,
                                   const Tensor &attn_weight,
                                   const int im2col_step);

void ms_deform_attn_cuda_backward(
    const Tensor &value, const Tensor &spatial_shapes,
    const Tensor &level_start_index, const Tensor &sampling_loc,
    const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value,
    Tensor &grad_sampling_loc, Tensor &grad_attn_weight,
    const int im2col_step);
#endif

// Dispatcher: validate the inputs and forward to the CUDA implementation.
// There is no CPU fallback for this op.
Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
                              const Tensor &level_start_index,
                              const Tensor &sampling_loc,
                              const Tensor &attn_weight,
                              const int im2col_step) {
  if (value.type().is_cuda()) {
#ifdef MMCV_WITH_CUDA
    CHECK_CUDA_INPUT(value)
    CHECK_CUDA_INPUT(spatial_shapes)
    CHECK_CUDA_INPUT(level_start_index)
    CHECK_CUDA_INPUT(sampling_loc)
    CHECK_CUDA_INPUT(attn_weight)
    return ms_deform_attn_cuda_forward(value, spatial_shapes,
                                       level_start_index, sampling_loc,
                                       attn_weight, im2col_step);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  }
  AT_ERROR("Not implemented on the CPU");
}

// The backward is void: the three gradients are written into the
// caller-provided grad_value / grad_sampling_loc / grad_attn_weight.
void ms_deform_attn_backward(const Tensor &value, const Tensor &spatial_shapes,
                             const Tensor &level_start_index,
                             const Tensor &sampling_loc,
                             const Tensor &attn_weight,
                             const Tensor &grad_output, Tensor &grad_value,
                             Tensor &grad_sampling_loc,
                             Tensor &grad_attn_weight,
                             const int im2col_step) {
  if (value.type().is_cuda()) {
#ifdef MMCV_WITH_CUDA
    CHECK_CUDA_INPUT(value)
    CHECK_CUDA_INPUT(spatial_shapes)
    CHECK_CUDA_INPUT(level_start_index)
    CHECK_CUDA_INPUT(sampling_loc)
    CHECK_CUDA_INPUT(attn_weight)
    CHECK_CUDA_INPUT(grad_output)
    CHECK_CUDA_INPUT(grad_value)
    CHECK_CUDA_INPUT(grad_sampling_loc)
    CHECK_CUDA_INPUT(grad_attn_weight)
    ms_deform_attn_cuda_backward(value, spatial_shapes, level_start_index,
                                 sampling_loc, attn_weight, grad_output,
                                 grad_value, grad_sampling_loc,
                                 grad_attn_weight, im2col_step);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  } else {
    AT_ERROR("Not implemented on the CPU");
  }
}
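Note the asymmetry between the two entry points: the forward allocates and returns its output tensor, while the backward fills caller-provided gradient buffers in place. Writing gradients into preallocated tensors is what lets the same C++ functions back both the PyTorch binding (pybind.cpp below) and the parrots binding, where an operator's outputs are allocated by the framework.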
mmcv/ops/csrc/parrots/ms_deform_attn_cuda.cu (new file, mode 100644)
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from
* https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <THC/THCAtomics.cuh>
#include <ms_deform_attn_cuda_kernel.cuh>
#include <vector>

// Launch the forward kernel: one thread per output element
// (batch * query * head * channel).
template <typename scalar_t>
void ms_deformable_im2col_cuda(
    cudaStream_t stream, const scalar_t *data_value,
    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
    const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
    const int batch_size, const int spatial_size, const int num_heads,
    const int channels, const int num_levels, const int num_query,
    const int num_point, scalar_t *data_col) {
  const int num_kernels = batch_size * num_query * num_heads * channels;
  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
  const int num_threads = CUDA_NUM_THREADS;
  ms_deformable_im2col_gpu_kernel<scalar_t>
      <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
         stream>>>(num_kernels, data_value, data_spatial_shapes,
                   data_level_start_index, data_sampling_loc, data_attn_weight,
                   batch_size, spatial_size, num_heads, channels, num_levels,
                   num_query, num_point, data_col);

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in ms_deformable_im2col_cuda: %s\n",
           cudaGetErrorString(err));
  }
}

// Launch the backward kernel, picking a reduction strategy from the
// channel count.
template <typename scalar_t>
void ms_deformable_col2im_cuda(
    cudaStream_t stream, const scalar_t *grad_col, const scalar_t *data_value,
    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
    const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
    const int batch_size, const int spatial_size, const int num_heads,
    const int channels, const int num_levels, const int num_query,
    const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
    scalar_t *grad_attn_weight) {
  const int num_threads =
      (channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
  const int num_kernels = batch_size * num_query * num_heads * channels;
  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
  if (channels > 1024) {
    if ((channels & 1023) == 0) {
      // Channel count is a multiple of 1024: shared-memory reduction
      // spread over multiple blocks.
      ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
             num_threads * 3 * sizeof(scalar_t), stream>>>(
              num_kernels, grad_col, data_value, data_spatial_shapes,
              data_level_start_index, data_sampling_loc, data_attn_weight,
              batch_size, spatial_size, num_heads, channels, num_levels,
              num_query, num_point, grad_value, grad_sampling_loc,
              grad_attn_weight);
    } else {
      // Fallback: accumulate directly in global memory.
      ms_deformable_col2im_gpu_kernel_gm<scalar_t>
          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
             stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
                       data_level_start_index, data_sampling_loc,
                       data_attn_weight, batch_size, spatial_size, num_heads,
                       channels, num_levels, num_query, num_point, grad_value,
                       grad_sampling_loc, grad_attn_weight);
    }
  } else {
    switch (channels) {
      case 1:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<
            scalar_t, 1>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
               stream>>>(num_kernels, grad_col, data_value,
                         data_spatial_shapes, data_level_start_index,
                         data_sampling_loc, data_attn_weight, batch_size,
                         spatial_size, num_heads, channels, num_levels,
                         num_query, num_point, grad_value, grad_sampling_loc,
                         grad_attn_weight);
        break;
      // Cases 2, 4, 8, 16 and 32 launch the same
      // ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 kernel
      // with the matching <scalar_t, blockSize> template arguments and
      // identical call arguments; they are elided here for brevity.
      case 64:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<
            scalar_t, 64>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
               stream>>>(num_kernels, grad_col, data_value,
                         data_spatial_shapes, data_level_start_index,
                         data_sampling_loc, data_attn_weight, batch_size,
                         spatial_size, num_heads, channels, num_levels,
                         num_query, num_point, grad_value, grad_sampling_loc,
                         grad_attn_weight);
        break;
      // Cases 128, 256, 512 and 1024 likewise launch
      // ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 with
      // the matching blockSize; elided for brevity.
      default:
        if (channels < 64) {
          ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
              <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
                 num_threads * 3 * sizeof(scalar_t), stream>>>(
                  num_kernels, grad_col, data_value, data_spatial_shapes,
                  data_level_start_index, data_sampling_loc, data_attn_weight,
                  batch_size, spatial_size, num_heads, channels, num_levels,
                  num_query, num_point, grad_value, grad_sampling_loc,
                  grad_attn_weight);
        } else {
          ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
              <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
                 num_threads * 3 * sizeof(scalar_t), stream>>>(
                  num_kernels, grad_col, data_value, data_spatial_shapes,
                  data_level_start_index, data_sampling_loc, data_attn_weight,
                  batch_size, spatial_size, num_heads, channels, num_levels,
                  num_query, num_point, grad_value, grad_sampling_loc,
                  grad_attn_weight);
        }
    }
  }
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in ms_deformable_col2im_cuda: %s\n",
           cudaGetErrorString(err));
  }
}

at::Tensor ms_deform_attn_cuda_forward(const at::Tensor &value,
                                       const at::Tensor &spatial_shapes,
                                       const at::Tensor &level_start_index,
                                       const at::Tensor &sampling_loc,
                                       const at::Tensor &attn_weight,
                                       const int im2col_step) {
  AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
  AT_ASSERTM(spatial_shapes.is_contiguous(),
             "spatial_shapes tensor has to be contiguous");
  AT_ASSERTM(level_start_index.is_contiguous(),
             "level_start_index tensor has to be contiguous");
  AT_ASSERTM(sampling_loc.is_contiguous(),
             "sampling_loc tensor has to be contiguous");
  AT_ASSERTM(attn_weight.is_contiguous(),
             "attn_weight tensor has to be contiguous");

  AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
  AT_ASSERTM(spatial_shapes.type().is_cuda(),
             "spatial_shapes must be a CUDA tensor");
  AT_ASSERTM(level_start_index.type().is_cuda(),
             "level_start_index must be a CUDA tensor");
  AT_ASSERTM(sampling_loc.type().is_cuda(),
             "sampling_loc must be a CUDA tensor");
  AT_ASSERTM(attn_weight.type().is_cuda(),
             "attn_weight must be a CUDA tensor");

  const int batch = value.size(0);
  const int spatial_size = value.size(1);
  const int num_heads = value.size(2);
  const int channels = value.size(3);

  const int num_levels = spatial_shapes.size(0);

  const int num_query = sampling_loc.size(1);
  const int num_point = sampling_loc.size(4);

  const int im2col_step_ = std::min(batch, im2col_step);

  AT_ASSERTM(batch % im2col_step_ == 0,
             "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);

  auto output =
      at::zeros({batch, num_query, num_heads, channels}, value.options());

  // Process the batch in chunks of im2col_step_ samples.
  const int batch_n = im2col_step_;
  auto output_n = output.view(
      {batch / im2col_step_, batch_n, num_query, num_heads, channels});
  auto per_value_size = spatial_size * num_heads * channels;
  auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
  auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
  for (int n = 0; n < batch / im2col_step_; ++n) {
    auto columns = output_n.select(0, n);
    AT_DISPATCH_FLOATING_TYPES(
        value.type(), "ms_deform_attn_forward_cuda", ([&] {
          ms_deformable_im2col_cuda(
              at::cuda::getCurrentCUDAStream(),
              value.data<scalar_t>() + n * im2col_step_ * per_value_size,
              spatial_shapes.data<int64_t>(),
              level_start_index.data<int64_t>(),
              sampling_loc.data<scalar_t>() +
                  n * im2col_step_ * per_sample_loc_size,
              attn_weight.data<scalar_t>() +
                  n * im2col_step_ * per_attn_weight_size,
              batch_n, spatial_size, num_heads, channels, num_levels,
              num_query, num_point, columns.data<scalar_t>());
        }));
  }

  output = output.view({batch, num_query, num_heads * channels});

  return output;
}

void ms_deform_attn_cuda_backward(
    const at::Tensor &value, const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index, const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight, const at::Tensor &grad_output,
    at::Tensor &grad_value, at::Tensor &grad_sampling_loc,
    at::Tensor &grad_attn_weight, const int im2col_step) {
  AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
  AT_ASSERTM(spatial_shapes.is_contiguous(),
             "spatial_shapes tensor has to be contiguous");
  AT_ASSERTM(level_start_index.is_contiguous(),
             "level_start_index tensor has to be contiguous");
  AT_ASSERTM(sampling_loc.is_contiguous(),
             "sampling_loc tensor has to be contiguous");
  AT_ASSERTM(attn_weight.is_contiguous(),
             "attn_weight tensor has to be contiguous");
  AT_ASSERTM(grad_output.is_contiguous(),
             "grad_output tensor has to be contiguous");

  AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
  AT_ASSERTM(spatial_shapes.type().is_cuda(),
             "spatial_shapes must be a CUDA tensor");
  AT_ASSERTM(level_start_index.type().is_cuda(),
             "level_start_index must be a CUDA tensor");
  AT_ASSERTM(sampling_loc.type().is_cuda(),
             "sampling_loc must be a CUDA tensor");
  AT_ASSERTM(attn_weight.type().is_cuda(),
             "attn_weight must be a CUDA tensor");
  AT_ASSERTM(grad_output.type().is_cuda(),
             "grad_output must be a CUDA tensor");

  const int batch = value.size(0);
  const int spatial_size = value.size(1);
  const int num_heads = value.size(2);
  const int channels = value.size(3);

  const int num_levels = spatial_shapes.size(0);

  const int num_query = sampling_loc.size(1);
  const int num_point = sampling_loc.size(4);

  const int im2col_step_ = std::min(batch, im2col_step);

  AT_ASSERTM(batch % im2col_step_ == 0,
             "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);

  const int batch_n = im2col_step_;
  auto per_value_size = spatial_size * num_heads * channels;
  auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
  auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
  auto grad_output_n = grad_output.view(
      {batch / im2col_step_, batch_n, num_query, num_heads, channels});

  for (int n = 0; n < batch / im2col_step_; ++n) {
    auto grad_output_g = grad_output_n.select(0, n);
    AT_DISPATCH_FLOATING_TYPES(
        value.type(), "ms_deform_attn_backward_cuda", ([&] {
          ms_deformable_col2im_cuda(
              at::cuda::getCurrentCUDAStream(),
              grad_output_g.data<scalar_t>(),
              value.data<scalar_t>() + n * im2col_step_ * per_value_size,
              spatial_shapes.data<int64_t>(),
              level_start_index.data<int64_t>(),
              sampling_loc.data<scalar_t>() +
                  n * im2col_step_ * per_sample_loc_size,
              attn_weight.data<scalar_t>() +
                  n * im2col_step_ * per_attn_weight_size,
              batch_n, spatial_size, num_heads, channels, num_levels,
              num_query, num_point,
              grad_value.data<scalar_t>() +
                  n * im2col_step_ * per_value_size,
              grad_sampling_loc.data<scalar_t>() +
                  n * im2col_step_ * per_sample_loc_size,
              grad_attn_weight.data<scalar_t>() +
                  n * im2col_step_ * per_attn_weight_size);
        }));
  }
}
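The arithmetic behind ms_deformable_im2col_gpu_kernel is easier to read in plain PyTorch. The sketch below is an illustrative re-implementation written for this page, not the multi_scale_deformable_attn_pytorch function the tests compare against, though it follows the same scheme: sampling locations in [0, 1] are rescaled to grid_sample's [-1, 1] range, each level is sampled bilinearly, and the per-point samples are mixed by the attention weights.

import torch
import torch.nn.functional as F


def naive_ms_deform_attn(value, spatial_shapes, sampling_locations,
                         attention_weights):
    """Slow reference; argument shapes match the CUDA op's inputs."""
    bs, _, num_heads, channels = value.shape
    _, num_queries, _, num_levels, num_points, _ = sampling_locations.shape
    level_sizes = [int(h * w) for h, w in spatial_shapes.tolist()]
    value_list = value.split(level_sizes, dim=1)
    out = value.new_zeros(bs, num_queries, num_heads, channels)
    for lvl, (h, w) in enumerate(spatial_shapes.tolist()):
        # (bs * num_heads, channels, h, w) feature map for this level.
        v = value_list[lvl].permute(0, 2, 3, 1).reshape(
            bs * num_heads, channels, h, w)
        # Sampling grid in [-1, 1]: (bs * num_heads, num_queries, num_points, 2).
        grid = 2 * sampling_locations[:, :, :, lvl] - 1
        grid = grid.permute(0, 2, 1, 3, 4).reshape(
            bs * num_heads, num_queries, num_points, 2)
        sampled = F.grid_sample(v, grid, mode='bilinear',
                                padding_mode='zeros', align_corners=False)
        # sampled: (bs * num_heads, channels, num_queries, num_points).
        w_l = attention_weights[:, :, :, lvl].permute(0, 2, 1, 3).reshape(
            bs * num_heads, 1, num_queries, num_points)
        acc = (sampled * w_l).sum(-1)  # weight and reduce over points
        out += acc.view(bs, num_heads, channels,
                        num_queries).permute(0, 3, 1, 2)
    return out.reshape(bs, num_queries, num_heads * channels)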
mmcv/ops/csrc/parrots/ms_deform_attn_parrots.cpp (new file, mode 100644)
#include <torch/extension.h>

#include <parrots/compute/aten.hpp>
#include <parrots/extension.hpp>
#include <parrots/foundation/ssattrs.hpp>

using namespace at;
using namespace parrots;

Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
                              const Tensor &level_start_index,
                              const Tensor &sampling_loc,
                              const Tensor &attn_weight,
                              const int im2col_step);

void ms_deform_attn_backward(const Tensor &value, const Tensor &spatial_shapes,
                             const Tensor &level_start_index,
                             const Tensor &sampling_loc,
                             const Tensor &attn_weight,
                             const Tensor &grad_output, Tensor &grad_value,
                             Tensor &grad_sampling_loc,
                             Tensor &grad_attn_weight, const int im2col_step);

// parrots adapter: unpack the im2col_step attribute and the input DArrays,
// call the shared C++ entry point, and write the result back to outs[0].
void ms_deform_attn_forward_parrots(CudaContext &ctx, const SSElement &attr,
                                    const OperatorBase::in_list_t &ins,
                                    OperatorBase::out_list_t &outs) {
  int im2col_step;
  SSAttrs(attr).get<int>("im2col_step", im2col_step).done();
  const auto &value = buildATensor(ctx, ins[0]);
  const auto &spatial_shapes = buildATensor(ctx, ins[1]);
  const auto &level_start_index = buildATensor(ctx, ins[2]);
  const auto &sampling_loc = buildATensor(ctx, ins[3]);
  const auto &attn_weight = buildATensor(ctx, ins[4]);
  auto out = ms_deform_attn_forward(value, spatial_shapes, level_start_index,
                                    sampling_loc, attn_weight, im2col_step);
  updateDArray(ctx, out, outs[0]);
}

// The backward writes its three gradients directly into the preallocated
// output DArrays, so nothing is returned.
void ms_deform_attn_backward_parrots(CudaContext &ctx, const SSElement &attr,
                                     const OperatorBase::in_list_t &ins,
                                     OperatorBase::out_list_t &outs) {
  int im2col_step;
  SSAttrs(attr).get<int>("im2col_step", im2col_step).done();
  const auto &value = buildATensor(ctx, ins[0]);
  const auto &spatial_shapes = buildATensor(ctx, ins[1]);
  const auto &level_start_index = buildATensor(ctx, ins[2]);
  const auto &sampling_loc = buildATensor(ctx, ins[3]);
  const auto &attn_weight = buildATensor(ctx, ins[4]);
  const auto &grad_output = buildATensor(ctx, ins[5]);
  auto grad_value = buildATensor(ctx, outs[0]);
  auto grad_sampling_loc = buildATensor(ctx, outs[1]);
  auto grad_attn_weight = buildATensor(ctx, outs[2]);
  ms_deform_attn_backward(value, spatial_shapes, level_start_index,
                          sampling_loc, attn_weight, grad_output, grad_value,
                          grad_sampling_loc, grad_attn_weight, im2col_step);
}

PARROTS_EXTENSION_REGISTER(ms_deform_attn_forward)
    .attr("im2col_step")
    .input(5)
    .output(1)
    .apply(ms_deform_attn_forward_parrots)
    .done();

PARROTS_EXTENSION_REGISTER(ms_deform_attn_backward)
    .attr("im2col_step")
    .input(6)
    .output(3)
    .apply(ms_deform_attn_backward_parrots)
    .done();
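The registration metadata mirrors the call sites: the forward is declared with five inputs and one output, so under parrots it has to be loaded as a value-returning op (hence its addition to has_return_value_ops in ext_loader.py below), while the backward takes six inputs and fills its three gradient outputs, outs[0] through outs[2], in place.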
mmcv/ops/csrc/pytorch/ms_deform_attn.cpp (modified)
@@ -19,11 +19,11 @@ Tensor ms_deform_attn_cuda_forward(const Tensor &value,
                                    const Tensor &attn_weight,
                                    const int im2col_step);
 
-std::vector<Tensor> ms_deform_attn_cuda_backward(
-    const Tensor &value, const Tensor &spatial_shapes,
-    const Tensor &level_start_index, const Tensor &sampling_loc,
-    const Tensor &attn_weight, const Tensor &grad_output,
-    const int im2col_step);
+void ms_deform_attn_cuda_backward(
+    const Tensor &value, const Tensor &spatial_shapes,
+    const Tensor &level_start_index, const Tensor &sampling_loc,
+    const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value,
+    Tensor &grad_sampling_loc, Tensor &grad_attn_weight,
+    const int im2col_step);
 
 #endif
@@ -48,13 +48,13 @@ Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
   AT_ERROR("Not implemented on the CPU");
 }
 
-std::vector<Tensor> ms_deform_attn_backward(const Tensor &value,
-                                            const Tensor &spatial_shapes,
-                                            const Tensor &level_start_index,
-                                            const Tensor &sampling_loc,
-                                            const Tensor &attn_weight,
-                                            const Tensor &grad_output,
-                                            const int im2col_step) {
+void ms_deform_attn_backward(const Tensor &value, const Tensor &spatial_shapes,
+                             const Tensor &level_start_index,
+                             const Tensor &sampling_loc,
+                             const Tensor &attn_weight,
+                             const Tensor &grad_output, Tensor &grad_value,
+                             Tensor &grad_sampling_loc,
+                             Tensor &grad_attn_weight, const int im2col_step) {
   if (value.type().is_cuda()) {
 #ifdef MMCV_WITH_CUDA
     CHECK_CUDA_INPUT(value)
@@ -63,12 +63,17 @@ std::vector<Tensor> ms_deform_attn_backward(const Tensor &value,
     CHECK_CUDA_INPUT(sampling_loc)
     CHECK_CUDA_INPUT(attn_weight)
     CHECK_CUDA_INPUT(grad_output)
-    return ms_deform_attn_cuda_backward(value, spatial_shapes,
-                                        level_start_index, sampling_loc,
-                                        attn_weight, grad_output, im2col_step);
+    CHECK_CUDA_INPUT(grad_value)
+    CHECK_CUDA_INPUT(grad_sampling_loc)
+    CHECK_CUDA_INPUT(grad_attn_weight)
+    ms_deform_attn_cuda_backward(value, spatial_shapes, level_start_index,
+                                 sampling_loc, attn_weight, grad_output,
+                                 grad_value, grad_sampling_loc,
+                                 grad_attn_weight, im2col_step);
 #else
     AT_ERROR("Not compiled with GPU support");
 #endif
+  } else {
+    AT_ERROR("Not implemented on the CPU");
   }
-  AT_ERROR("Not implemented on the CPU");
 }
mmcv/ops/csrc/pytorch/ms_deform_attn_cuda.cu (modified)
@@ -286,11 +286,12 @@ at::Tensor ms_deform_attn_cuda_forward(const at::Tensor &value,
   return output;
 }
 
-std::vector<at::Tensor> ms_deform_attn_cuda_backward(
-    const at::Tensor &value, const at::Tensor &spatial_shapes,
-    const at::Tensor &level_start_index, const at::Tensor &sampling_loc,
-    const at::Tensor &attn_weight, const at::Tensor &grad_output,
-    const int im2col_step) {
+void ms_deform_attn_cuda_backward(
+    const at::Tensor &value, const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index, const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight, const at::Tensor &grad_output,
+    at::Tensor &grad_value, at::Tensor &grad_sampling_loc,
+    at::Tensor &grad_attn_weight, const int im2col_step) {
   AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
   AT_ASSERTM(spatial_shapes.is_contiguous(),
              "spatial_shapes tensor has to be contiguous");
@@ -328,10 +329,6 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
   AT_ASSERTM(batch % im2col_step_ == 0,
              "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
 
-  auto grad_value = at::zeros_like(value);
-  auto grad_sampling_loc = at::zeros_like(sampling_loc);
-  auto grad_attn_weight = at::zeros_like(attn_weight);
-
   const int batch_n = im2col_step_;
   auto per_value_size = spatial_size * num_heads * channels;
   auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
@@ -360,6 +357,4 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
               n * im2col_step_ * per_attn_weight_size);
         }));
   }
-
-  return {grad_value, grad_sampling_loc, grad_attn_weight};
 }
mmcv/ops/csrc/pytorch/pybind.cpp (modified)
@@ -97,13 +97,13 @@ Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
                               const Tensor &sampling_loc,
                               const Tensor &attn_weight, const int im2col_step);
 
-std::vector<Tensor> ms_deform_attn_backward(const Tensor &value,
-                                            const Tensor &spatial_shapes,
-                                            const Tensor &level_start_index,
-                                            const Tensor &sampling_loc,
-                                            const Tensor &attn_weight,
-                                            const Tensor &grad_output,
-                                            const int im2col_step);
+void ms_deform_attn_backward(const Tensor &value, const Tensor &spatial_shapes,
+                             const Tensor &level_start_index,
+                             const Tensor &sampling_loc,
+                             const Tensor &attn_weight,
+                             const Tensor &grad_output, Tensor &grad_value,
+                             Tensor &grad_sampling_loc,
+                             Tensor &grad_attn_weight, const int im2col_step);
 
 Tensor nms(Tensor boxes, Tensor scores, float iou_threshold, int offset);
@@ -445,5 +445,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("value"), py::arg("value_spatial_shapes"),
         py::arg("value_level_start_index"), py::arg("sampling_locations"),
         py::arg("attention_weights"), py::arg("grad_output"),
-        py::arg("im2col_step"));
+        py::arg("grad_value"), py::arg("grad_sampling_loc"),
+        py::arg("grad_attn_weight"), py::arg("im2col_step"));
 }
mmcv/ops/multi_scale_deform_attn.py (modified)
@@ -35,11 +35,13 @@ class MultiScaleDeformableAttnFunction(Function):
         """
         ctx.im2col_step = im2col_step
-        output = ext_module.ms_deform_attn_forward(value, value_spatial_shapes,
-                                                   value_level_start_index,
-                                                   sampling_locations,
-                                                   attention_weights,
-                                                   ctx.im2col_step)
+        output = ext_module.ms_deform_attn_forward(
+            value,
+            value_spatial_shapes,
+            value_level_start_index,
+            sampling_locations,
+            attention_weights,
+            im2col_step=ctx.im2col_step)
         ctx.save_for_backward(value, value_spatial_shapes,
                               value_level_start_index, sampling_locations,
                               attention_weights)
@@ -60,15 +62,21 @@ class MultiScaleDeformableAttnFunction(Function):
         """
         value, value_spatial_shapes, value_level_start_index,\
             sampling_locations, attention_weights = ctx.saved_tensors
-        grad_value, grad_sampling_loc, grad_attn_weight = \
-            ext_module.ms_deform_attn_backward(
-                value, value_spatial_shapes, value_level_start_index,
-                sampling_locations, attention_weights, grad_output,
-                ctx.im2col_step)
+        grad_value = torch.zeros_like(value)
+        grad_sampling_loc = torch.zeros_like(sampling_locations)
+        grad_attn_weight = torch.zeros_like(attention_weights)
+        ext_module.ms_deform_attn_backward(
+            value,
+            value_spatial_shapes,
+            value_level_start_index,
+            sampling_locations,
+            attention_weights,
+            grad_output.contiguous(),
+            grad_value,
+            grad_sampling_loc,
+            grad_attn_weight,
+            im2col_step=ctx.im2col_step)
 
         return grad_value, None, None, \
             grad_sampling_loc, grad_attn_weight, None
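This hunk is the crux of the interface change: the Python side now allocates the gradient buffers with torch.zeros_like and passes them to the extension as outputs, instead of unpacking them from a return value, so the identical autograd code drives both the PyTorch and the parrots builds. Note also that grad_output is made contiguous before the call, satisfying the contiguity assert in the CUDA backward.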
mmcv/utils/ext_loader.py (modified)
@@ -16,22 +16,46 @@ else:
     from parrots import extension
 
     has_return_value_ops = [
-        'nms', 'softnms', 'nms_match', 'nms_rotated', 'top_pool_forward',
-        'top_pool_backward', 'bottom_pool_forward', 'bottom_pool_backward',
-        'left_pool_forward', 'left_pool_backward', 'right_pool_forward',
-        'right_pool_backward', 'fused_bias_leakyrelu', 'upfirdn2d'
+        'nms',
+        'softnms',
+        'nms_match',
+        'nms_rotated',
+        'top_pool_forward',
+        'top_pool_backward',
+        'bottom_pool_forward',
+        'bottom_pool_backward',
+        'left_pool_forward',
+        'left_pool_backward',
+        'right_pool_forward',
+        'right_pool_backward',
+        'fused_bias_leakyrelu',
+        'upfirdn2d',
+        'ms_deform_attn_forward',
     ]
 
+    def get_fake_func(name):
+
+        def fake_func(*args, **kwargs):
+            raise RuntimeError(
+                '{} is not supported in parrots now'.format(name))
+
+        return fake_func
+
     def load_ext(name, funcs):
         ExtModule = namedtuple('ExtModule', funcs)
         ext_list = []
         lib_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
         for fun in funcs:
-            if fun in has_return_value_ops:
-                ext_list.append(extension.load(fun, name, lib_dir=lib_root).op)
+            try:
+                ext_fun = extension.load(fun, name, lib_dir=lib_root)
+            except Exception:
+                ext_fun = get_fake_func(fun)
+                ext_list.append(ext_fun)
             else:
-                ext_list.append(
-                    extension.load(fun, name, lib_dir=lib_root).op_)
+                if fun in has_return_value_ops:
+                    ext_list.append(ext_fun.op)
+                else:
+                    ext_list.append(ext_fun.op_)
         return ExtModule(*ext_list)
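For reference, this loader is how the op modules elsewhere in mmcv pick up the compiled functions on either backend; a short sketch, following the '_ext' module-name pattern used across mmcv/ops:

from mmcv.utils import ext_loader

# On PyTorch, load_ext imports the compiled torch extension. On parrots,
# each name is resolved with extension.load; a failed load is replaced by a
# fake function that raises only when it is actually called.
# ms_deform_attn_forward resolves to .op (value-returning) on parrots,
# ms_deform_attn_backward to .op_ (in-place outputs).
ext_module = ext_loader.load_ext(
    '_ext', ['ms_deform_attn_forward', 'ms_deform_attn_backward'])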
tests/test_ops/test_ms_deformable_attn.py (modified)
 import pytest
 import torch
-from torch.autograd import gradcheck
 
 from mmcv.ops.multi_scale_deform_attn import (
     MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch)
 
+_USING_PARROTS = True
+try:
+    from parrots.autograd import gradcheck
+except ImportError:
+    from torch.autograd import gradcheck
+    _USING_PARROTS = False
+
 
 def test_forward_multi_scale_deformable_attn_pytorch():
     N, M, D = 1, 2, 2
@@ -118,8 +124,13 @@ def test_gradient_numerical(channels,
     value.requires_grad = grad_value
     sampling_locations.requires_grad = grad_sampling_loc
     attention_weights.requires_grad = grad_attn_weight
 
-    assert gradcheck(
-        func,
-        (value.double(), shapes, level_start_index,
-         sampling_locations.double(), attention_weights.double(),
-         im2col_step))
+    if _USING_PARROTS:
+        assert gradcheck(
+            func, (value.double(), shapes, level_start_index,
+                   sampling_locations.double(), attention_weights.double(),
+                   im2col_step),
+            no_grads=[shapes, level_start_index])
+    else:
+        assert gradcheck(func, (value.double(), shapes, level_start_index,
+                                sampling_locations.double(),
+                                attention_weights.double(), im2col_step))
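A note on the two branches: parrots' gradcheck accepts a no_grads argument to exclude the integer shapes and level_start_index tensors from numerical differentiation, while torch.autograd.gradcheck simply skips inputs whose requires_grad is False, which is why the torch branch can pass the same tuple unchanged.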