OpenDAS / MMCV · Commit d929fa41 (Unverified)
Authored Mar 24, 2022 by q.yao; committed by GitHub on Mar 24, 2022. Parent: 5b5d0c15

Fix ms deform attn (#1823)

* rename grad_sampling_loc and grad_attn_weight
* recover cache initialize
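
What the diff does, in short: the backward kernels used to advance their grad_sampling_loc / grad_attn_weight pointer parameters in place inside CUDA_1D_KERNEL_LOOP; they now derive local grad_sampling_loc_out / grad_attn_weight_out pointers on each loop iteration, and the shared-memory caches are declared once before the loop. Below is a minimal, self-contained sketch (not from this commit; the toy_* names and the simplified loop macro are hypothetical) of why mutating the parameters inside a grid-stride loop is fragile: a thread that handles more than one index starts its second index from an already-shifted pointer.

```cpp
#include <cuda_runtime.h>

// Simplified stand-in for the grid-stride loop macro from the CUDA helpers.
#define CUDA_1D_KERNEL_LOOP(i, n)                              \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

// Buggy shape: the pointer parameter itself is advanced, so offsets compound
// across the iterations handled by one thread.
__global__ void toy_col2im_buggy(const int n, float *grad_loc) {
  CUDA_1D_KERNEL_LOOP(index, n) {
    grad_loc += index * 2;   // leaks into this thread's next iteration
    grad_loc[0] = 1.f;
    grad_loc[1] = 1.f;
  }
}

// Fixed shape (what the renamed *_out pointers do): a fresh local pointer is
// derived from the untouched parameter on every iteration.
__global__ void toy_col2im_fixed(const int n, float *grad_loc) {
  CUDA_1D_KERNEL_LOOP(index, n) {
    float *grad_loc_out = grad_loc + index * 2;
    grad_loc_out[0] = 1.f;
    grad_loc_out[1] = 1.f;
  }
}

int main() {
  const int n = 1024;
  float *d_grad = nullptr;
  cudaMalloc(&d_grad, n * 2 * sizeof(float));
  cudaMemset(d_grad, 0, n * 2 * sizeof(float));
  // Deliberately launch fewer threads than indices so the grid-stride loop
  // iterates more than once per thread; only the fixed variant stays in bounds.
  toy_col2im_fixed<<<2, 128>>>(n, d_grad);
  cudaDeviceSynchronize();
  cudaFree(d_grad);
  return 0;
}
```
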
Showing 2 changed files with 73 additions and 79 deletions:

* mmcv/ops/csrc/common/cuda/ms_deform_attn_cuda_kernel.cuh (+69, -65)
* mmcv/ops/csrc/pytorch/cuda/ms_deform_attn_cuda.cu (+4, -14)
mmcv/ops/csrc/common/cuda/ms_deform_attn_cuda_kernel.cuh

@@ -14,8 +14,6 @@
 #include "common_cuda_helper.hpp"
 #include "pytorch_cuda_helper.hpp"
 
-const int CUDA_NUM_THREADS = 1024;
-
 template <typename scalar_t>
 __device__ scalar_t ms_deform_attn_im2col_bilinear(
     const scalar_t *&bottom_data, const int &height, const int &width,

@@ -264,10 +262,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
-  CUDA_1D_KERNEL_LOOP(index, n) {
   __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
   __shared__ scalar_t cache_grad_attn_weight[blockSize];
   unsigned int tid = threadIdx.x;
+  const int qid_stride = num_heads * channels;
+  CUDA_1D_KERNEL_LOOP(index, n) {
     int _temp = index;
     const int c_col = _temp % channels;
     _temp /= channels;
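
This hunk is the "recover cache initialize" part of the commit message: the __shared__ reduction caches, the thread index, and the loop-invariant qid_stride are declared once at block scope instead of inside the grid-stride loop. A hedged sketch of that structure follows (the block_cache_sketch name is hypothetical; it assumes, as the kernels here do, that the launch is sized so every thread of a block runs the same number of loop iterations, otherwise the in-loop __syncthreads() would be unsafe).

```cpp
#include <cuda_runtime.h>

#define CUDA_1D_KERNEL_LOOP(i, n)                              \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

// __shared__ storage is per block and lives for the whole launch, so the cache
// is declared once before the loop and reused (and re-synchronized) on every
// iteration; values that do not depend on `index` are hoisted with it.
template <typename scalar_t, unsigned int blockSize>
__global__ void block_cache_sketch(const int n, const scalar_t *in,
                                   scalar_t *out) {
  __shared__ scalar_t cache[blockSize];
  unsigned int tid = threadIdx.x;
  CUDA_1D_KERNEL_LOOP(index, n) {
    cache[tid] = in[index];
    __syncthreads();
    if (tid == 0) {
      scalar_t sum = 0;
      for (unsigned int _tid = 0; _tid < blockSize; ++_tid) sum += cache[_tid];
      out[index / blockSize] = sum;  // one partial result per block of indices
    }
    __syncthreads();
  }
}
```
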

@@ -282,11 +281,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
-    const int qid_stride = num_heads * channels;
     const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

     for (int l_col = 0; l_col < num_levels; ++l_col) {

@@ -323,23 +322,23 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
                    _grad_h = cache_grad_sampling_loc[1],
                    _grad_a = cache_grad_attn_weight[0];
           int sid = 2;
-          for (unsigned int tid = 1; tid < blockSize; ++tid) {
+          for (unsigned int _tid = 1; _tid < blockSize; ++_tid) {
             _grad_w += cache_grad_sampling_loc[sid];
             _grad_h += cache_grad_sampling_loc[sid + 1];
-            _grad_a += cache_grad_attn_weight[tid];
+            _grad_a += cache_grad_attn_weight[_tid];
             sid += 2;
           }

-          *grad_sampling_loc = _grad_w;
-          *(grad_sampling_loc + 1) = _grad_h;
-          *grad_attn_weight = _grad_a;
+          *grad_sampling_loc_out = _grad_w;
+          *(grad_sampling_loc_out + 1) = _grad_h;
+          *grad_attn_weight_out = _grad_a;
         }
         __syncthreads();

         data_weight_ptr += 1;
         data_loc_w_ptr += 2;
-        grad_attn_weight += grad_weight_stride;
-        grad_sampling_loc += grad_loc_stride;
+        grad_attn_weight_out += grad_weight_stride;
+        grad_sampling_loc_out += grad_loc_stride;
       }
     }
   }
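
With tid now declared before the grid-stride loop, the serial tail reduction's counter is renamed to _tid so it no longer shadows the thread index; cache_grad_attn_weight is still indexed by the loop counter, so behavior is unchanged and only the shadowing goes away. A tiny hedged illustration (the sum_kernel name is hypothetical):

```cpp
#include <cuda_runtime.h>

// Declaring a second `tid` inside the for-loop would be legal but would hide the
// thread index above it for the whole loop body; using `_tid` keeps both visible.
__global__ void sum_kernel(const float *cache, float *out) {
  unsigned int tid = threadIdx.x;  // thread index, used for the tid == 0 guard
  if (tid == 0) {
    float sum = 0.f;
    for (unsigned int _tid = 0; _tid < blockDim.x; ++_tid) {  // was: unsigned int tid
      sum += cache[_tid];
    }
    *out = sum;
  }
}
```
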

@@ -354,10 +353,10 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
-  CUDA_1D_KERNEL_LOOP(index, n) {
   __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
   __shared__ scalar_t cache_grad_attn_weight[blockSize];
   unsigned int tid = threadIdx.x;
+  CUDA_1D_KERNEL_LOOP(index, n) {
     int _temp = index;
     const int c_col = _temp % channels;
     _temp /= channels;

@@ -372,8 +371,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;

@@ -422,16 +422,16 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
         }
         if (tid == 0) {
-          *grad_sampling_loc = cache_grad_sampling_loc[0];
-          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
-          *grad_attn_weight = cache_grad_attn_weight[0];
+          *grad_sampling_loc_out = cache_grad_sampling_loc[0];
+          *(grad_sampling_loc_out + 1) = cache_grad_sampling_loc[1];
+          *grad_attn_weight_out = cache_grad_attn_weight[0];
         }
         __syncthreads();

         data_weight_ptr += 1;
         data_loc_w_ptr += 2;
-        grad_attn_weight += grad_weight_stride;
-        grad_sampling_loc += grad_loc_stride;
+        grad_attn_weight_out += grad_weight_stride;
+        grad_sampling_loc_out += grad_loc_stride;
       }
     }
   }

@@ -446,11 +446,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
-  CUDA_1D_KERNEL_LOOP(index, n) {
   extern __shared__ int _s[];
   scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
   scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
   unsigned int tid = threadIdx.x;
+  CUDA_1D_KERNEL_LOOP(index, n) {
     int _temp = index;
     const int c_col = _temp % channels;
     _temp /= channels;

@@ -465,8 +465,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;

@@ -506,23 +507,23 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(
                    _grad_h = cache_grad_sampling_loc[1],
                    _grad_a = cache_grad_attn_weight[0];
           int sid = 2;
-          for (unsigned int tid = 1; tid < blockDim.x; ++tid) {
+          for (unsigned int _tid = 1; _tid < blockDim.x; ++_tid) {
             _grad_w += cache_grad_sampling_loc[sid];
             _grad_h += cache_grad_sampling_loc[sid + 1];
-            _grad_a += cache_grad_attn_weight[tid];
+            _grad_a += cache_grad_attn_weight[_tid];
             sid += 2;
           }

-          *grad_sampling_loc = _grad_w;
-          *(grad_sampling_loc + 1) = _grad_h;
-          *grad_attn_weight = _grad_a;
+          *grad_sampling_loc_out = _grad_w;
+          *(grad_sampling_loc_out + 1) = _grad_h;
+          *grad_attn_weight_out = _grad_a;
         }
         __syncthreads();

         data_weight_ptr += 1;
         data_loc_w_ptr += 2;
-        grad_attn_weight += grad_weight_stride;
-        grad_sampling_loc += grad_loc_stride;
+        grad_attn_weight_out += grad_weight_stride;
+        grad_sampling_loc_out += grad_loc_stride;
       }
     }
   }

@@ -537,11 +538,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
-  CUDA_1D_KERNEL_LOOP(index, n) {
   extern __shared__ int _s[];
   scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
   scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
   unsigned int tid = threadIdx.x;
+  CUDA_1D_KERNEL_LOOP(index, n) {
     int _temp = index;
     const int c_col = _temp % channels;
     _temp /= channels;

@@ -556,8 +557,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;

@@ -615,16 +617,16 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(
         }
         if (tid == 0) {
-          *grad_sampling_loc = cache_grad_sampling_loc[0];
-          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
-          *grad_attn_weight = cache_grad_attn_weight[0];
+          *grad_sampling_loc_out = cache_grad_sampling_loc[0];
+          *(grad_sampling_loc_out + 1) = cache_grad_sampling_loc[1];
+          *grad_attn_weight_out = cache_grad_attn_weight[0];
         }
         __syncthreads();

         data_weight_ptr += 1;
         data_loc_w_ptr += 2;
-        grad_attn_weight += grad_weight_stride;
-        grad_sampling_loc += grad_loc_stride;
+        grad_attn_weight_out += grad_weight_stride;
+        grad_sampling_loc_out += grad_loc_stride;
       }
     }
   }

@@ -639,11 +641,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
-  CUDA_1D_KERNEL_LOOP(index, n) {
   extern __shared__ int _s[];
   scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
   scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
   unsigned int tid = threadIdx.x;
+  CUDA_1D_KERNEL_LOOP(index, n) {
     int _temp = index;
     const int c_col = _temp % channels;
     _temp /= channels;

@@ -658,8 +660,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;

@@ -717,16 +720,16 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
         }
         if (tid == 0) {
-          atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
-          atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
-          atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
+          atomicAdd(grad_sampling_loc_out, cache_grad_sampling_loc[0]);
+          atomicAdd(grad_sampling_loc_out + 1, cache_grad_sampling_loc[1]);
+          atomicAdd(grad_attn_weight_out, cache_grad_attn_weight[0]);
         }
         __syncthreads();

         data_weight_ptr += 1;
         data_loc_w_ptr += 2;
-        grad_attn_weight += grad_weight_stride;
-        grad_sampling_loc += grad_loc_stride;
+        grad_attn_weight_out += grad_weight_stride;
+        grad_sampling_loc_out += grad_loc_stride;
       }
     }
   }
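
Note the contrast with the single-block kernels above: the _multi_blocks variant is selected when channels exceed one block, so several blocks may reduce into the same grad_sampling_loc / grad_attn_weight slot, and thread 0 accumulates with atomicAdd rather than a plain store. A hedged sketch of that accumulation pattern (the partial_sum_kernel name is hypothetical):

```cpp
#include <cuda_runtime.h>

// Each block folds its shared-memory partial sum into global memory with
// atomicAdd, so blocks that share an output element accumulate instead of
// overwriting one another. blockDim.x is assumed to be a power of two.
__global__ void partial_sum_kernel(const int n, const float *in, float *out) {
  extern __shared__ float cache[];
  unsigned int tid = threadIdx.x;
  int index = blockIdx.x * blockDim.x + threadIdx.x;

  cache[tid] = (index < n) ? in[index] : 0.f;
  __syncthreads();

  // Standard tree reduction within the block.
  for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (tid < s) cache[tid] += cache[tid + s];
    __syncthreads();
  }

  if (tid == 0) atomicAdd(out, cache[0]);  // accumulate, do not overwrite
}
```

A launch would pass the dynamic shared-memory size explicitly, e.g. partial_sum_kernel<<<blocks, threads, threads * sizeof(float), stream>>>(n, in, out), which mirrors the num_threads * 3 * sizeof(scalar_t) argument the launcher in ms_deform_attn_cuda.cu supplies below.
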

@@ -756,8 +759,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_gm(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;

@@ -784,12 +788,12 @@ __global__ void ms_deformable_col2im_gpu_kernel_gm(
         ms_deform_attn_col2im_bilinear_gm(
             data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
             w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
-            grad_sampling_loc, grad_attn_weight);
+            grad_sampling_loc_out, grad_attn_weight_out);
       }
       data_weight_ptr += 1;
       data_loc_w_ptr += 2;
-      grad_attn_weight += grad_weight_stride;
-      grad_sampling_loc += grad_loc_stride;
+      grad_attn_weight_out += grad_weight_stride;
+      grad_sampling_loc_out += grad_loc_stride;
     }
   }
 }

mmcv/ops/csrc/pytorch/cuda/ms_deform_attn_cuda.cu

@@ -31,7 +31,7 @@ void ms_deformable_im2col_cuda(cudaStream_t stream, const scalar_t *data_value,
     const int num_point, scalar_t *data_col) {
   const int num_kernels = batch_size * num_query * num_heads * channels;
   const int num_actual_kernels = batch_size * num_query * num_heads * channels;
-  const int num_threads = CUDA_NUM_THREADS;
+  const int num_threads = THREADS_PER_BLOCK;
   ms_deformable_im2col_gpu_kernel<scalar_t>
       <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
           num_kernels, data_value, data_spatial_shapes, data_level_start_index,

@@ -54,11 +54,11 @@ void ms_deformable_col2im_cuda(
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
   const int num_threads =
-      (channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
+      (channels > THREADS_PER_BLOCK) ? THREADS_PER_BLOCK : channels;
   const int num_kernels = batch_size * num_query * num_heads * channels;
   const int num_actual_kernels = batch_size * num_query * num_heads * channels;
-  if (channels > 1024) {
-    if ((channels & 1023) == 0) {
+  if (channels > THREADS_PER_BLOCK) {
+    if ((channels & THREADS_PER_BLOCK - 1) == 0) {
       ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
           <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
              num_threads * 3 * sizeof(scalar_t), stream>>>(
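
The launcher now uses the shared THREADS_PER_BLOCK constant in place of the local CUDA_NUM_THREADS that the header hunk deleted, so the kernels and the launch configuration agree on a single block size. The multiple-of-block-size test works because that constant is a power of two, and channels & THREADS_PER_BLOCK - 1 parses as channels & (THREADS_PER_BLOCK - 1). A small hedged sketch of the arithmetic (kThreadsPerBlock and GetBlocks are stand-ins for THREADS_PER_BLOCK and GET_BLOCKS; the value 1024 is only illustrative):

```cpp
#include <cassert>

constexpr int kThreadsPerBlock = 1024;  // stand-in; assumed to be a power of two

// Ceil-divide, the same idea a GET_BLOCKS-style helper implements.
constexpr int GetBlocks(int n, int threads) { return (n + threads - 1) / threads; }

int main() {
  int channels = 2048;
  // Zero exactly when channels is a multiple of the power-of-two block size.
  assert((channels & (kThreadsPerBlock - 1)) == 0);
  // Threads per block are capped; the remaining work becomes additional blocks.
  int num_threads = (channels > kThreadsPerBlock) ? kThreadsPerBlock : channels;
  assert(num_threads == kThreadsPerBlock);
  assert(GetBlocks(channels, num_threads) == 2);
  return 0;
}
```
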

@@ -178,16 +178,6 @@ void ms_deformable_col2im_cuda(
                   channels, num_levels, num_query, num_point, grad_value,
                   grad_sampling_loc, grad_attn_weight);
           break;
-        case 1024:
-          ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
-              <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
-                 stream>>>(
-                  num_kernels, grad_col, data_value, data_spatial_shapes,
-                  data_level_start_index, data_sampling_loc, data_attn_weight,
-                  batch_size, spatial_size, num_heads, channels, num_levels,
-                  num_query, num_point, grad_value, grad_sampling_loc,
-                  grad_attn_weight);
-          break;
         default:
           if (channels < 64) {
             ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>