Unverified Commit d929fa41 authored by q.yao, committed by GitHub

Fix ms deform attn (#1823)

* rename grad_sampling_loc and grad_attn_weight

* recover cache initialize
parent 5b5d0c15
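
Note on the first bullet: the backward kernels advanced their grad_sampling_loc / grad_attn_weight arguments in place inside CUDA_1D_KERNEL_LOOP, which is a grid-stride loop. If the launcher caps the grid size (as GET_BLOCKS implementations commonly do), a thread processes more than one index, and every iteration after the first offsets from the already-shifted base pointer, scattering gradients to the wrong locations. The fix below derives fresh *_out pointers each iteration instead. A minimal, self-contained sketch of the pattern, with illustrative names rather than code from this diff:

// CUDA_1D_KERNEL_LOOP reproduced from its usual grid-stride definition.
#define CUDA_1D_KERNEL_LOOP(i, n)                              \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

// Buggy: the kernel argument itself is advanced, so the second
// grid-stride iteration offsets from a stale base.
__global__ void scatter_buggy(float *grad, int n) {
  CUDA_1D_KERNEL_LOOP(index, n) {
    grad += index;  // mutates the base pointer for later iterations
    *grad = 1.f;
  }
}

// Fixed, mirroring this commit: compute a local _out pointer per
// iteration and leave the argument untouched.
__global__ void scatter_fixed(float *grad, int n) {
  CUDA_1D_KERNEL_LOOP(index, n) {
    float *grad_out = grad + index;
    *grad_out = 1.f;
  }
}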
@@ -14,8 +14,6 @@
 #include "common_cuda_helper.hpp"
 #include "pytorch_cuda_helper.hpp"
 
-const int CUDA_NUM_THREADS = 1024;
-
 template <typename scalar_t>
 __device__ scalar_t ms_deform_attn_im2col_bilinear(
     const scalar_t *&bottom_data, const int &height, const int &width,
@@ -264,10 +262,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
-  CUDA_1D_KERNEL_LOOP(index, n) {
-    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
-    __shared__ scalar_t cache_grad_attn_weight[blockSize];
-    unsigned int tid = threadIdx.x;
+  __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+  __shared__ scalar_t cache_grad_attn_weight[blockSize];
+  unsigned int tid = threadIdx.x;
+  const int qid_stride = num_heads * channels;
+  CUDA_1D_KERNEL_LOOP(index, n) {
     int _temp = index;
     const int c_col = _temp % channels;
     _temp /= channels;
@@ -282,11 +281,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
-    const int qid_stride = num_heads * channels;
     const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
     for (int l_col = 0; l_col < num_levels; ++l_col) {
@@ -323,23 +322,23 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
                  _grad_h = cache_grad_sampling_loc[1],
                  _grad_a = cache_grad_attn_weight[0];
         int sid = 2;
-        for (unsigned int tid = 1; tid < blockSize; ++tid) {
+        for (unsigned int _tid = 1; _tid < blockSize; ++_tid) {
           _grad_w += cache_grad_sampling_loc[sid];
           _grad_h += cache_grad_sampling_loc[sid + 1];
-          _grad_a += cache_grad_attn_weight[tid];
+          _grad_a += cache_grad_attn_weight[_tid];
           sid += 2;
         }
 
-        *grad_sampling_loc = _grad_w;
-        *(grad_sampling_loc + 1) = _grad_h;
-        *grad_attn_weight = _grad_a;
+        *grad_sampling_loc_out = _grad_w;
+        *(grad_sampling_loc_out + 1) = _grad_h;
+        *grad_attn_weight_out = _grad_a;
       }
       __syncthreads();
 
       data_weight_ptr += 1;
       data_loc_w_ptr += 2;
-      grad_attn_weight += grad_weight_stride;
-      grad_sampling_loc += grad_loc_stride;
+      grad_attn_weight_out += grad_weight_stride;
+      grad_sampling_loc_out += grad_loc_stride;
     }
   }
 }
@@ -354,10 +353,10 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
-  CUDA_1D_KERNEL_LOOP(index, n) {
-    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
-    __shared__ scalar_t cache_grad_attn_weight[blockSize];
-    unsigned int tid = threadIdx.x;
+  __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+  __shared__ scalar_t cache_grad_attn_weight[blockSize];
+  unsigned int tid = threadIdx.x;
+  CUDA_1D_KERNEL_LOOP(index, n) {
     int _temp = index;
     const int c_col = _temp % channels;
     _temp /= channels;
@@ -372,8 +371,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;
...@@ -422,16 +422,16 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2( ...@@ -422,16 +422,16 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
} }
if (tid == 0) { if (tid == 0) {
*grad_sampling_loc = cache_grad_sampling_loc[0]; *grad_sampling_loc_out = cache_grad_sampling_loc[0];
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; *(grad_sampling_loc_out + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight = cache_grad_attn_weight[0]; *grad_attn_weight_out = cache_grad_attn_weight[0];
} }
__syncthreads(); __syncthreads();
data_weight_ptr += 1; data_weight_ptr += 1;
data_loc_w_ptr += 2; data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride; grad_attn_weight_out += grad_weight_stride;
grad_sampling_loc += grad_loc_stride; grad_sampling_loc_out += grad_loc_stride;
} }
} }
} }
@@ -446,11 +446,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
-  CUDA_1D_KERNEL_LOOP(index, n) {
-    extern __shared__ int _s[];
-    scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
-    scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
-    unsigned int tid = threadIdx.x;
+  extern __shared__ int _s[];
+  scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
+  scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+  unsigned int tid = threadIdx.x;
+  CUDA_1D_KERNEL_LOOP(index, n) {
     int _temp = index;
     const int c_col = _temp % channels;
     _temp /= channels;
@@ -465,8 +465,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;
@@ -506,23 +507,23 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(
                  _grad_h = cache_grad_sampling_loc[1],
                  _grad_a = cache_grad_attn_weight[0];
         int sid = 2;
-        for (unsigned int tid = 1; tid < blockDim.x; ++tid) {
+        for (unsigned int _tid = 1; _tid < blockDim.x; ++_tid) {
           _grad_w += cache_grad_sampling_loc[sid];
           _grad_h += cache_grad_sampling_loc[sid + 1];
-          _grad_a += cache_grad_attn_weight[tid];
+          _grad_a += cache_grad_attn_weight[_tid];
           sid += 2;
         }
 
-        *grad_sampling_loc = _grad_w;
-        *(grad_sampling_loc + 1) = _grad_h;
-        *grad_attn_weight = _grad_a;
+        *grad_sampling_loc_out = _grad_w;
+        *(grad_sampling_loc_out + 1) = _grad_h;
+        *grad_attn_weight_out = _grad_a;
       }
       __syncthreads();
 
       data_weight_ptr += 1;
       data_loc_w_ptr += 2;
-      grad_attn_weight += grad_weight_stride;
-      grad_sampling_loc += grad_loc_stride;
+      grad_attn_weight_out += grad_weight_stride;
+      grad_sampling_loc_out += grad_loc_stride;
     }
   }
 }
@@ -537,11 +538,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
-  CUDA_1D_KERNEL_LOOP(index, n) {
-    extern __shared__ int _s[];
-    scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
-    scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
-    unsigned int tid = threadIdx.x;
+  extern __shared__ int _s[];
+  scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
+  scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+  unsigned int tid = threadIdx.x;
+  CUDA_1D_KERNEL_LOOP(index, n) {
     int _temp = index;
     const int c_col = _temp % channels;
     _temp /= channels;
@@ -556,8 +557,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;
...@@ -615,16 +617,16 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2( ...@@ -615,16 +617,16 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(
} }
if (tid == 0) { if (tid == 0) {
*grad_sampling_loc = cache_grad_sampling_loc[0]; *grad_sampling_loc_out = cache_grad_sampling_loc[0];
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; *(grad_sampling_loc_out + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight = cache_grad_attn_weight[0]; *grad_attn_weight_out = cache_grad_attn_weight[0];
} }
__syncthreads(); __syncthreads();
data_weight_ptr += 1; data_weight_ptr += 1;
data_loc_w_ptr += 2; data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride; grad_attn_weight_out += grad_weight_stride;
grad_sampling_loc += grad_loc_stride; grad_sampling_loc_out += grad_loc_stride;
} }
} }
} }
@@ -639,11 +641,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
-  CUDA_1D_KERNEL_LOOP(index, n) {
-    extern __shared__ int _s[];
-    scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
-    scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
-    unsigned int tid = threadIdx.x;
+  extern __shared__ int _s[];
+  scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
+  scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+  unsigned int tid = threadIdx.x;
+  CUDA_1D_KERNEL_LOOP(index, n) {
     int _temp = index;
     const int c_col = _temp % channels;
     _temp /= channels;
@@ -658,8 +660,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;
...@@ -717,16 +720,16 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks( ...@@ -717,16 +720,16 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
} }
if (tid == 0) { if (tid == 0) {
atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); atomicAdd(grad_sampling_loc_out, cache_grad_sampling_loc[0]);
atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); atomicAdd(grad_sampling_loc_out + 1, cache_grad_sampling_loc[1]);
atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); atomicAdd(grad_attn_weight_out, cache_grad_attn_weight[0]);
} }
__syncthreads(); __syncthreads();
data_weight_ptr += 1; data_weight_ptr += 1;
data_loc_w_ptr += 2; data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride; grad_attn_weight_out += grad_weight_stride;
grad_sampling_loc += grad_loc_stride; grad_sampling_loc_out += grad_loc_stride;
} }
} }
} }
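
For context, the multi_blocks variant patched above reduces each block's partial gradients in shared memory and then has thread 0 publish them with a single atomicAdd, since several blocks can contribute to the same output slot. A simplified single-scalar sketch of that reduce-then-atomic pattern (hypothetical kernel, not from this diff; assumes blockDim.x is a power of two and scalar_t supports atomicAdd):

// Launch with dynamic shared memory of blockDim.x * sizeof(scalar_t), e.g.
//   block_reduce_atomic_add<float><<<grid, block, block * sizeof(float)>>>(in, out, n);
template <typename scalar_t>
__global__ void block_reduce_atomic_add(const scalar_t *in, scalar_t *out,
                                        int n) {
  extern __shared__ unsigned char smem[];
  scalar_t *cache = reinterpret_cast<scalar_t *>(smem);
  unsigned int tid = threadIdx.x;
  int i = blockIdx.x * blockDim.x + tid;
  cache[tid] = (i < n) ? in[i] : scalar_t(0);
  __syncthreads();
  // Tree reduction in shared memory; requires a power-of-two block size.
  for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (tid < s) cache[tid] += cache[tid + s];
    __syncthreads();
  }
  // One atomic per block instead of one per thread.
  if (tid == 0) atomicAdd(out, cache[0]);
}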
@@ -756,8 +759,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_gm(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;
@@ -784,12 +788,12 @@ __global__ void ms_deformable_col2im_gpu_kernel_gm(
         ms_deform_attn_col2im_bilinear_gm(
             data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
             w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
-            grad_sampling_loc, grad_attn_weight);
+            grad_sampling_loc_out, grad_attn_weight_out);
       }
       data_weight_ptr += 1;
       data_loc_w_ptr += 2;
-      grad_attn_weight += grad_weight_stride;
-      grad_sampling_loc += grad_loc_stride;
+      grad_attn_weight_out += grad_weight_stride;
+      grad_sampling_loc_out += grad_loc_stride;
     }
   }
 }
...
@@ -31,7 +31,7 @@ void ms_deformable_im2col_cuda(cudaStream_t stream, const scalar_t *data_value,
                                const int num_point, scalar_t *data_col) {
   const int num_kernels = batch_size * num_query * num_heads * channels;
   const int num_actual_kernels = batch_size * num_query * num_heads * channels;
-  const int num_threads = CUDA_NUM_THREADS;
+  const int num_threads = THREADS_PER_BLOCK;
   ms_deformable_im2col_gpu_kernel<scalar_t>
       <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
           num_kernels, data_value, data_spatial_shapes, data_level_start_index,
@@ -54,11 +54,11 @@ void ms_deformable_col2im_cuda(
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
   const int num_threads =
-      (channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
+      (channels > THREADS_PER_BLOCK) ? THREADS_PER_BLOCK : channels;
   const int num_kernels = batch_size * num_query * num_heads * channels;
   const int num_actual_kernels = batch_size * num_query * num_heads * channels;
-  if (channels > 1024) {
-    if ((channels & 1023) == 0) {
+  if (channels > THREADS_PER_BLOCK) {
+    if ((channels & THREADS_PER_BLOCK - 1) == 0) {
       ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
           <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
              num_threads * 3 * sizeof(scalar_t), stream>>>(
@@ -178,16 +178,6 @@ void ms_deformable_col2im_cuda(
                 channels, num_levels, num_query, num_point, grad_value,
                 grad_sampling_loc, grad_attn_weight);
         break;
-      case 1024:
-        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
-                                                                      1024>
-            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
-               stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
-                         data_level_start_index, data_sampling_loc,
-                         data_attn_weight, batch_size, spatial_size, num_heads,
-                         channels, num_levels, num_query, num_point, grad_value,
-                         grad_sampling_loc, grad_attn_weight);
-        break;
       default:
         if (channels < 64) {
           ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
...
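
A note on the dispatch test in ms_deformable_col2im_cuda: because binary & binds looser than -, the expression (channels & THREADS_PER_BLOCK - 1) parses as channels & (THREADS_PER_BLOCK - 1), which is zero exactly when channels is a multiple of THREADS_PER_BLOCK, provided THREADS_PER_BLOCK is a power of two (the removed CUDA_NUM_THREADS was 1024; the real value comes from common_cuda_helper.hpp). A host-side sketch of the check, using an assumed value and a hypothetical constant name:

#include <cassert>

constexpr int kThreadsPerBlock = 1024;  // assumed power-of-two value

// True iff channels is a multiple of kThreadsPerBlock: all of the low
// log2(kThreadsPerBlock) bits of channels must be zero.
constexpr bool is_block_multiple(int channels) {
  return (channels & (kThreadsPerBlock - 1)) == 0;
}

int main() {
  assert(is_block_multiple(2048));   // takes the shm_reduce_v2_multi_blocks path
  assert(!is_block_multiple(1536));  // takes the other branch (not shown in this diff)
  return 0;
}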