Commit 7739077a authored by Josh Fromm's avatar Josh Fromm Committed by Facebook GitHub Bot
Browse files

Hipify various dependencies to enable AMD Face Enhancer

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/675

This diff extends several targets to be hip compatible and fixes a few silly hipification issues with those targets.

After these changes, all dependencies needed for the face enhancer can compile with AMD.

A few silly issues that I had to hack around, maybe we could improve hipification to avoid similar issues in the future:
* Some of the dependencies used sources in `src/cuda/**.cu`. Hipification tried to rename "cuda" to "hip" and broke the paths. I'm not sure where that rename happens so I just changed the directory from "cuda" to "gpu" to avoid the issue.
* One header import called `THCAtomics.cuh` was incorrectly being renamed to `THHAtomics.cuh`, which doesn't exist. Fortunately, an equivalent import that doesn't have naming issues was available.

We also might want to consider graduating the cpp_library_hip bazel helper out of fbgemm since it seems pretty generally useful.

For some of the targets, we needed to build a Python C++ extension, which as far as I can tell we didn't yet have good hipification support for. I added a new Buck rule, very similar to our standard cpp_library_hip rule, that creates an extension instead. It's a little copy-pasted, so if there are cleaner ways to work around this requirement, let me know.

Reviewed By: houseroad

Differential Revision: D61080247

fbshipit-source-id: dc6f101eb3eadfd43ef5610c651b1639e4c78ae6
parent e09224b8
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
*/ */
#include <vector> #include <vector>
#include "cuda/ms_deform_im2col_cuda.cuh" #include "ms_deform_im2col_cuda.cuh"
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
at::Tensor ms_deform_attn_cuda_forward( at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value, const at::Tensor &value,
const at::Tensor &spatial_shapes, const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index, const at::Tensor &level_start_index,
const at::Tensor &sampling_loc, const at::Tensor &sampling_loc,
...@@ -50,7 +50,7 @@ at::Tensor ms_deform_attn_cuda_forward( ...@@ -50,7 +50,7 @@ at::Tensor ms_deform_attn_cuda_forward(
const int im2col_step_ = std::min(batch, im2col_step); const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
const int batch_n = im2col_step_; const int batch_n = im2col_step_;
...@@ -81,7 +81,7 @@ at::Tensor ms_deform_attn_cuda_forward( ...@@ -81,7 +81,7 @@ at::Tensor ms_deform_attn_cuda_forward(
std::vector<at::Tensor> ms_deform_attn_cuda_backward( std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value, const at::Tensor &value,
const at::Tensor &spatial_shapes, const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index, const at::Tensor &level_start_index,
const at::Tensor &sampling_loc, const at::Tensor &sampling_loc,
...@@ -127,7 +127,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward( ...@@ -127,7 +127,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
for (int n = 0; n < batch/im2col_step_; ++n) for (int n = 0; n < batch/im2col_step_; ++n)
{ {
auto grad_output_g = grad_output_n.select(0, n); auto grad_output_g = grad_output_n.select(0, n);
...@@ -150,4 +150,4 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward( ...@@ -150,4 +150,4 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
return { return {
grad_value, grad_sampling_loc, grad_attn_weight grad_value, grad_sampling_loc, grad_attn_weight
}; };
} }
\ No newline at end of file
...@@ -15,8 +15,7 @@ ...@@ -15,8 +15,7 @@
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Atomic.cuh>
#include <THC/THCAtomics.cuh>
#define CUDA_KERNEL_LOOP(i, n) \ #define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
...@@ -31,7 +30,7 @@ inline int GET_BLOCKS(const int N, const int num_threads) ...@@ -31,7 +30,7 @@ inline int GET_BLOCKS(const int N, const int num_threads)
template <typename scalar_t> template <typename scalar_t>
__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
const int &height, const int &width, const int &nheads, const int &channels, const int &height, const int &width, const int &nheads, const int &channels,
const scalar_t &h, const scalar_t &w, const int &m, const int &c) const scalar_t &h, const scalar_t &w, const int &m, const int &c)
{ {
...@@ -85,12 +84,12 @@ __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, ...@@ -85,12 +84,12 @@ __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
template <typename scalar_t> template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
const int &height, const int &width, const int &nheads, const int &channels, const int &height, const int &width, const int &nheads, const int &channels,
const scalar_t &h, const scalar_t &w, const int &m, const int &c, const scalar_t &h, const scalar_t &w, const int &m, const int &c,
const scalar_t &top_grad, const scalar_t &top_grad,
const scalar_t &attn_weight, const scalar_t &attn_weight,
scalar_t* &grad_value, scalar_t* &grad_value,
scalar_t* grad_sampling_loc, scalar_t* grad_sampling_loc,
scalar_t* grad_attn_weight) scalar_t* grad_attn_weight)
{ {
...@@ -140,7 +139,7 @@ __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, ...@@ -140,7 +139,7 @@ __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
v3 = bottom_data[ptr3]; v3 = bottom_data[ptr3];
grad_h_weight += hw * v3; grad_h_weight += hw * v3;
grad_w_weight -= lh * v3; grad_w_weight -= lh * v3;
atomicAdd(grad_value+ptr3, w3*top_grad_value); atomicAdd(grad_value+ptr3, w3*top_grad_value);
} }
scalar_t v4 = 0; scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) if (h_high <= height - 1 && w_high <= width - 1)
...@@ -160,12 +159,12 @@ __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, ...@@ -160,12 +159,12 @@ __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
template <typename scalar_t> template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
const int &height, const int &width, const int &nheads, const int &channels, const int &height, const int &width, const int &nheads, const int &channels,
const scalar_t &h, const scalar_t &w, const int &m, const int &c, const scalar_t &h, const scalar_t &w, const int &m, const int &c,
const scalar_t &top_grad, const scalar_t &top_grad,
const scalar_t &attn_weight, const scalar_t &attn_weight,
scalar_t* &grad_value, scalar_t* &grad_value,
scalar_t* grad_sampling_loc, scalar_t* grad_sampling_loc,
scalar_t* grad_attn_weight) scalar_t* grad_attn_weight)
{ {
...@@ -215,7 +214,7 @@ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, ...@@ -215,7 +214,7 @@ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
v3 = bottom_data[ptr3]; v3 = bottom_data[ptr3];
grad_h_weight += hw * v3; grad_h_weight += hw * v3;
grad_w_weight -= lh * v3; grad_w_weight -= lh * v3;
atomicAdd(grad_value+ptr3, w3*top_grad_value); atomicAdd(grad_value+ptr3, w3*top_grad_value);
} }
scalar_t v4 = 0; scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) if (h_high <= height - 1 && w_high <= width - 1)
...@@ -228,7 +227,7 @@ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, ...@@ -228,7 +227,7 @@ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
} }
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
atomicAdd(grad_attn_weight, top_grad * val); atomicAdd(grad_attn_weight, top_grad * val);
atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
} }
...@@ -236,15 +235,15 @@ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, ...@@ -236,15 +235,15 @@ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
template <typename scalar_t> template <typename scalar_t>
__global__ void ms_deformable_im2col_gpu_kernel(const int n, __global__ void ms_deformable_im2col_gpu_kernel(const int n,
const scalar_t *data_value, const scalar_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index, const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc, const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight, const scalar_t *data_attn_weight,
const int batch_size, const int batch_size,
const int spatial_size, const int spatial_size,
const int num_heads, const int num_heads,
const int channels, const int channels,
const int num_levels, const int num_levels,
const int num_query, const int num_query,
const int num_point, const int num_point,
...@@ -255,7 +254,7 @@ __global__ void ms_deformable_im2col_gpu_kernel(const int n, ...@@ -255,7 +254,7 @@ __global__ void ms_deformable_im2col_gpu_kernel(const int n,
int _temp = index; int _temp = index;
const int c_col = _temp % channels; const int c_col = _temp % channels;
_temp /= channels; _temp /= channels;
const int sampling_index = _temp; const int sampling_index = _temp;
const int m_col = _temp % num_heads; const int m_col = _temp % num_heads;
_temp /= num_heads; _temp /= num_heads;
const int q_col = _temp % num_query; const int q_col = _temp % num_query;
...@@ -268,7 +267,7 @@ __global__ void ms_deformable_im2col_gpu_kernel(const int n, ...@@ -268,7 +267,7 @@ __global__ void ms_deformable_im2col_gpu_kernel(const int n,
const int qid_stride = num_heads * channels; const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
scalar_t col = 0; scalar_t col = 0;
for (int l_col=0; l_col < num_levels; ++l_col) for (int l_col=0; l_col < num_levels; ++l_col)
{ {
const int level_start_id = data_level_start_index[l_col]; const int level_start_id = data_level_start_index[l_col];
...@@ -303,13 +302,13 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(co ...@@ -303,13 +302,13 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(co
const scalar_t *grad_col, const scalar_t *grad_col,
const scalar_t *data_value, const scalar_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index, const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc, const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight, const scalar_t *data_attn_weight,
const int batch_size, const int batch_size,
const int spatial_size, const int spatial_size,
const int num_heads, const int num_heads,
const int channels, const int channels,
const int num_levels, const int num_levels,
const int num_query, const int num_query,
const int num_point, const int num_point,
...@@ -325,7 +324,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(co ...@@ -325,7 +324,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(co
int _temp = index; int _temp = index;
const int c_col = _temp % channels; const int c_col = _temp % channels;
_temp /= channels; _temp /= channels;
const int sampling_index = _temp; const int sampling_index = _temp;
const int m_col = _temp % num_heads; const int m_col = _temp % num_heads;
_temp /= num_heads; _temp /= num_heads;
const int q_col = _temp % num_query; const int q_col = _temp % num_query;
...@@ -369,10 +368,10 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(co ...@@ -369,10 +368,10 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(co
{ {
ms_deform_attn_col2im_bilinear( ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr, top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
} }
__syncthreads(); __syncthreads();
if (tid == 0) if (tid == 0)
{ {
...@@ -385,8 +384,8 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(co ...@@ -385,8 +384,8 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(co
_grad_a += cache_grad_attn_weight[tid]; _grad_a += cache_grad_attn_weight[tid];
sid += 2; sid += 2;
} }
*grad_sampling_loc = _grad_w; *grad_sampling_loc = _grad_w;
*(grad_sampling_loc + 1) = _grad_h; *(grad_sampling_loc + 1) = _grad_h;
*grad_attn_weight = _grad_a; *grad_attn_weight = _grad_a;
...@@ -408,13 +407,13 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(co ...@@ -408,13 +407,13 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(co
const scalar_t *grad_col, const scalar_t *grad_col,
const scalar_t *data_value, const scalar_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index, const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc, const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight, const scalar_t *data_attn_weight,
const int batch_size, const int batch_size,
const int spatial_size, const int spatial_size,
const int num_heads, const int num_heads,
const int channels, const int channels,
const int num_levels, const int num_levels,
const int num_query, const int num_query,
const int num_point, const int num_point,
...@@ -430,7 +429,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(co ...@@ -430,7 +429,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(co
int _temp = index; int _temp = index;
const int c_col = _temp % channels; const int c_col = _temp % channels;
_temp /= channels; _temp /= channels;
const int sampling_index = _temp; const int sampling_index = _temp;
const int m_col = _temp % num_heads; const int m_col = _temp % num_heads;
_temp /= num_heads; _temp /= num_heads;
const int q_col = _temp % num_query; const int q_col = _temp % num_query;
...@@ -474,10 +473,10 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(co ...@@ -474,10 +473,10 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(co
{ {
ms_deform_attn_col2im_bilinear( ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr, top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
} }
__syncthreads(); __syncthreads();
for (unsigned int s=blockSize/2; s>0; s>>=1) for (unsigned int s=blockSize/2; s>0; s>>=1)
...@@ -493,7 +492,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(co ...@@ -493,7 +492,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(co
} }
if (tid == 0) if (tid == 0)
{ {
*grad_sampling_loc = cache_grad_sampling_loc[0]; *grad_sampling_loc = cache_grad_sampling_loc[0];
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight = cache_grad_attn_weight[0]; *grad_attn_weight = cache_grad_attn_weight[0];
...@@ -515,13 +514,13 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, ...@@ -515,13 +514,13 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
const scalar_t *grad_col, const scalar_t *grad_col,
const scalar_t *data_value, const scalar_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index, const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc, const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight, const scalar_t *data_attn_weight,
const int batch_size, const int batch_size,
const int spatial_size, const int spatial_size,
const int num_heads, const int num_heads,
const int channels, const int channels,
const int num_levels, const int num_levels,
const int num_query, const int num_query,
const int num_point, const int num_point,
...@@ -538,7 +537,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, ...@@ -538,7 +537,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
int _temp = index; int _temp = index;
const int c_col = _temp % channels; const int c_col = _temp % channels;
_temp /= channels; _temp /= channels;
const int sampling_index = _temp; const int sampling_index = _temp;
const int m_col = _temp % num_heads; const int m_col = _temp % num_heads;
_temp /= num_heads; _temp /= num_heads;
const int q_col = _temp % num_query; const int q_col = _temp % num_query;
...@@ -582,10 +581,10 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, ...@@ -582,10 +581,10 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
{ {
ms_deform_attn_col2im_bilinear( ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr, top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
} }
__syncthreads(); __syncthreads();
if (tid == 0) if (tid == 0)
{ {
...@@ -598,8 +597,8 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, ...@@ -598,8 +597,8 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
_grad_a += cache_grad_attn_weight[tid]; _grad_a += cache_grad_attn_weight[tid];
sid += 2; sid += 2;
} }
*grad_sampling_loc = _grad_w; *grad_sampling_loc = _grad_w;
*(grad_sampling_loc + 1) = _grad_h; *(grad_sampling_loc + 1) = _grad_h;
*grad_attn_weight = _grad_a; *grad_attn_weight = _grad_a;
...@@ -620,13 +619,13 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, ...@@ -620,13 +619,13 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
const scalar_t *grad_col, const scalar_t *grad_col,
const scalar_t *data_value, const scalar_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index, const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc, const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight, const scalar_t *data_attn_weight,
const int batch_size, const int batch_size,
const int spatial_size, const int spatial_size,
const int num_heads, const int num_heads,
const int channels, const int channels,
const int num_levels, const int num_levels,
const int num_query, const int num_query,
const int num_point, const int num_point,
...@@ -643,7 +642,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, ...@@ -643,7 +642,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
int _temp = index; int _temp = index;
const int c_col = _temp % channels; const int c_col = _temp % channels;
_temp /= channels; _temp /= channels;
const int sampling_index = _temp; const int sampling_index = _temp;
const int m_col = _temp % num_heads; const int m_col = _temp % num_heads;
_temp /= num_heads; _temp /= num_heads;
const int q_col = _temp % num_query; const int q_col = _temp % num_query;
...@@ -687,10 +686,10 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, ...@@ -687,10 +686,10 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
{ {
ms_deform_attn_col2im_bilinear( ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr, top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
} }
__syncthreads(); __syncthreads();
for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
...@@ -706,7 +705,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, ...@@ -706,7 +705,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
} }
} }
__syncthreads(); __syncthreads();
} }
...@@ -733,13 +732,13 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const ...@@ -733,13 +732,13 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const
const scalar_t *grad_col, const scalar_t *grad_col,
const scalar_t *data_value, const scalar_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index, const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc, const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight, const scalar_t *data_attn_weight,
const int batch_size, const int batch_size,
const int spatial_size, const int spatial_size,
const int num_heads, const int num_heads,
const int channels, const int channels,
const int num_levels, const int num_levels,
const int num_query, const int num_query,
const int num_point, const int num_point,
...@@ -756,7 +755,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const ...@@ -756,7 +755,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const
int _temp = index; int _temp = index;
const int c_col = _temp % channels; const int c_col = _temp % channels;
_temp /= channels; _temp /= channels;
const int sampling_index = _temp; const int sampling_index = _temp;
const int m_col = _temp % num_heads; const int m_col = _temp % num_heads;
_temp /= num_heads; _temp /= num_heads;
const int q_col = _temp % num_query; const int q_col = _temp % num_query;
...@@ -800,10 +799,10 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const ...@@ -800,10 +799,10 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const
{ {
ms_deform_attn_col2im_bilinear( ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr, top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
} }
__syncthreads(); __syncthreads();
for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
...@@ -847,13 +846,13 @@ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, ...@@ -847,13 +846,13 @@ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
const scalar_t *grad_col, const scalar_t *grad_col,
const scalar_t *data_value, const scalar_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index, const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc, const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight, const scalar_t *data_attn_weight,
const int batch_size, const int batch_size,
const int spatial_size, const int spatial_size,
const int num_heads, const int num_heads,
const int channels, const int channels,
const int num_levels, const int num_levels,
const int num_query, const int num_query,
const int num_point, const int num_point,
...@@ -866,7 +865,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, ...@@ -866,7 +865,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
int _temp = index; int _temp = index;
const int c_col = _temp % channels; const int c_col = _temp % channels;
_temp /= channels; _temp /= channels;
const int sampling_index = _temp; const int sampling_index = _temp;
const int m_col = _temp % num_heads; const int m_col = _temp % num_heads;
_temp /= num_heads; _temp /= num_heads;
const int q_col = _temp % num_query; const int q_col = _temp % num_query;
...@@ -907,7 +906,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, ...@@ -907,7 +906,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
{ {
ms_deform_attn_col2im_bilinear_gm( ms_deform_attn_col2im_bilinear_gm(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr, top_grad, weight, grad_value_ptr,
grad_sampling_loc, grad_attn_weight); grad_sampling_loc, grad_attn_weight);
} }
data_weight_ptr += 1; data_weight_ptr += 1;
...@@ -923,15 +922,15 @@ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, ...@@ -923,15 +922,15 @@ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
template <typename scalar_t> template <typename scalar_t>
void ms_deformable_im2col_cuda(cudaStream_t stream, void ms_deformable_im2col_cuda(cudaStream_t stream,
const scalar_t* data_value, const scalar_t* data_value,
const int64_t* data_spatial_shapes, const int64_t* data_spatial_shapes,
const int64_t* data_level_start_index, const int64_t* data_level_start_index,
const scalar_t* data_sampling_loc, const scalar_t* data_sampling_loc,
const scalar_t* data_attn_weight, const scalar_t* data_attn_weight,
const int batch_size, const int batch_size,
const int spatial_size, const int spatial_size,
const int num_heads, const int num_heads,
const int channels, const int channels,
const int num_levels, const int num_levels,
const int num_query, const int num_query,
const int num_point, const int num_point,
scalar_t* data_col) scalar_t* data_col)
...@@ -942,9 +941,9 @@ void ms_deformable_im2col_cuda(cudaStream_t stream, ...@@ -942,9 +941,9 @@ void ms_deformable_im2col_cuda(cudaStream_t stream,
ms_deformable_im2col_gpu_kernel<scalar_t> ms_deformable_im2col_gpu_kernel<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>( 0, stream>>>(
num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col); batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
cudaError_t err = cudaGetLastError(); cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) if (err != cudaSuccess)
{ {
...@@ -961,13 +960,13 @@ void ms_deformable_col2im_cuda(cudaStream_t stream, ...@@ -961,13 +960,13 @@ void ms_deformable_col2im_cuda(cudaStream_t stream,
const int64_t * data_level_start_index, const int64_t * data_level_start_index,
const scalar_t * data_sampling_loc, const scalar_t * data_sampling_loc,
const scalar_t * data_attn_weight, const scalar_t * data_attn_weight,
const int batch_size, const int batch_size,
const int spatial_size, const int spatial_size,
const int num_heads, const int num_heads,
const int channels, const int channels,
const int num_levels, const int num_levels,
const int num_query, const int num_query,
const int num_point, const int num_point,
scalar_t* grad_value, scalar_t* grad_value,
scalar_t* grad_sampling_loc, scalar_t* grad_sampling_loc,
scalar_t* grad_attn_weight) scalar_t* grad_attn_weight)
...@@ -982,17 +981,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream, ...@@ -982,17 +981,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream,
ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t> ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads*3*sizeof(scalar_t), stream>>>( num_threads*3*sizeof(scalar_t), stream>>>(
num_kernels, num_kernels,
grad_col, grad_col,
data_value, data_value,
data_spatial_shapes, data_spatial_shapes,
data_level_start_index, data_level_start_index,
data_sampling_loc, data_sampling_loc,
data_attn_weight, data_attn_weight,
batch_size, batch_size,
spatial_size, spatial_size,
num_heads, num_heads,
channels, channels,
num_levels, num_levels,
num_query, num_query,
num_point, num_point,
...@@ -1005,17 +1004,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream, ...@@ -1005,17 +1004,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream,
ms_deformable_col2im_gpu_kernel_gm<scalar_t> ms_deformable_col2im_gpu_kernel_gm<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>( 0, stream>>>(
num_kernels, num_kernels,
grad_col, grad_col,
data_value, data_value,
data_spatial_shapes, data_spatial_shapes,
data_level_start_index, data_level_start_index,
data_sampling_loc, data_sampling_loc,
data_attn_weight, data_attn_weight,
batch_size, batch_size,
spatial_size, spatial_size,
num_heads, num_heads,
channels, channels,
num_levels, num_levels,
num_query, num_query,
num_point, num_point,
...@@ -1031,17 +1030,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream, ...@@ -1031,17 +1030,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream,
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1> ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>( 0, stream>>>(
num_kernels, num_kernels,
grad_col, grad_col,
data_value, data_value,
data_spatial_shapes, data_spatial_shapes,
data_level_start_index, data_level_start_index,
data_sampling_loc, data_sampling_loc,
data_attn_weight, data_attn_weight,
batch_size, batch_size,
spatial_size, spatial_size,
num_heads, num_heads,
channels, channels,
num_levels, num_levels,
num_query, num_query,
num_point, num_point,
...@@ -1053,17 +1052,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream, ...@@ -1053,17 +1052,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream,
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2> ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>( 0, stream>>>(
num_kernels, num_kernels,
grad_col, grad_col,
data_value, data_value,
data_spatial_shapes, data_spatial_shapes,
data_level_start_index, data_level_start_index,
data_sampling_loc, data_sampling_loc,
data_attn_weight, data_attn_weight,
batch_size, batch_size,
spatial_size, spatial_size,
num_heads, num_heads,
channels, channels,
num_levels, num_levels,
num_query, num_query,
num_point, num_point,
...@@ -1075,17 +1074,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream, ...@@ -1075,17 +1074,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream,
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4> ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>( 0, stream>>>(
num_kernels, num_kernels,
grad_col, grad_col,
data_value, data_value,
data_spatial_shapes, data_spatial_shapes,
data_level_start_index, data_level_start_index,
data_sampling_loc, data_sampling_loc,
data_attn_weight, data_attn_weight,
batch_size, batch_size,
spatial_size, spatial_size,
num_heads, num_heads,
channels, channels,
num_levels, num_levels,
num_query, num_query,
num_point, num_point,
...@@ -1097,17 +1096,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream, ...@@ -1097,17 +1096,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream,
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8> ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>( 0, stream>>>(
num_kernels, num_kernels,
grad_col, grad_col,
data_value, data_value,
data_spatial_shapes, data_spatial_shapes,
data_level_start_index, data_level_start_index,
data_sampling_loc, data_sampling_loc,
data_attn_weight, data_attn_weight,
batch_size, batch_size,
spatial_size, spatial_size,
num_heads, num_heads,
channels, channels,
num_levels, num_levels,
num_query, num_query,
num_point, num_point,
...@@ -1119,17 +1118,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream, ...@@ -1119,17 +1118,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream,
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16> ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>( 0, stream>>>(
num_kernels, num_kernels,
grad_col, grad_col,
data_value, data_value,
data_spatial_shapes, data_spatial_shapes,
data_level_start_index, data_level_start_index,
data_sampling_loc, data_sampling_loc,
data_attn_weight, data_attn_weight,
batch_size, batch_size,
spatial_size, spatial_size,
num_heads, num_heads,
channels, channels,
num_levels, num_levels,
num_query, num_query,
num_point, num_point,
...@@ -1141,17 +1140,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream, ...@@ -1141,17 +1140,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream,
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32> ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>( 0, stream>>>(
num_kernels, num_kernels,
grad_col, grad_col,
data_value, data_value,
data_spatial_shapes, data_spatial_shapes,
data_level_start_index, data_level_start_index,
data_sampling_loc, data_sampling_loc,
data_attn_weight, data_attn_weight,
batch_size, batch_size,
spatial_size, spatial_size,
num_heads, num_heads,
channels, channels,
num_levels, num_levels,
num_query, num_query,
num_point, num_point,
...@@ -1163,17 +1162,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream, ...@@ -1163,17 +1162,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream,
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64> ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>( 0, stream>>>(
num_kernels, num_kernels,
grad_col, grad_col,
data_value, data_value,
data_spatial_shapes, data_spatial_shapes,
data_level_start_index, data_level_start_index,
data_sampling_loc, data_sampling_loc,
data_attn_weight, data_attn_weight,
batch_size, batch_size,
spatial_size, spatial_size,
num_heads, num_heads,
channels, channels,
num_levels, num_levels,
num_query, num_query,
num_point, num_point,
...@@ -1185,17 +1184,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream, ...@@ -1185,17 +1184,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream,
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128> ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>( 0, stream>>>(
num_kernels, num_kernels,
grad_col, grad_col,
data_value, data_value,
data_spatial_shapes, data_spatial_shapes,
data_level_start_index, data_level_start_index,
data_sampling_loc, data_sampling_loc,
data_attn_weight, data_attn_weight,
batch_size, batch_size,
spatial_size, spatial_size,
num_heads, num_heads,
channels, channels,
num_levels, num_levels,
num_query, num_query,
num_point, num_point,
...@@ -1207,17 +1206,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream, ...@@ -1207,17 +1206,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream,
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256> ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>( 0, stream>>>(
num_kernels, num_kernels,
grad_col, grad_col,
data_value, data_value,
data_spatial_shapes, data_spatial_shapes,
data_level_start_index, data_level_start_index,
data_sampling_loc, data_sampling_loc,
data_attn_weight, data_attn_weight,
batch_size, batch_size,
spatial_size, spatial_size,
num_heads, num_heads,
channels, channels,
num_levels, num_levels,
num_query, num_query,
num_point, num_point,
...@@ -1229,17 +1228,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream, ...@@ -1229,17 +1228,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream,
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512> ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>( 0, stream>>>(
num_kernels, num_kernels,
grad_col, grad_col,
data_value, data_value,
data_spatial_shapes, data_spatial_shapes,
data_level_start_index, data_level_start_index,
data_sampling_loc, data_sampling_loc,
data_attn_weight, data_attn_weight,
batch_size, batch_size,
spatial_size, spatial_size,
num_heads, num_heads,
channels, channels,
num_levels, num_levels,
num_query, num_query,
num_point, num_point,
...@@ -1251,17 +1250,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream, ...@@ -1251,17 +1250,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream,
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024> ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>( 0, stream>>>(
num_kernels, num_kernels,
grad_col, grad_col,
data_value, data_value,
data_spatial_shapes, data_spatial_shapes,
data_level_start_index, data_level_start_index,
data_sampling_loc, data_sampling_loc,
data_attn_weight, data_attn_weight,
batch_size, batch_size,
spatial_size, spatial_size,
num_heads, num_heads,
channels, channels,
num_levels, num_levels,
num_query, num_query,
num_point, num_point,
...@@ -1275,17 +1274,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream, ...@@ -1275,17 +1274,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream,
ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t> ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads*3*sizeof(scalar_t), stream>>>( num_threads*3*sizeof(scalar_t), stream>>>(
num_kernels, num_kernels,
grad_col, grad_col,
data_value, data_value,
data_spatial_shapes, data_spatial_shapes,
data_level_start_index, data_level_start_index,
data_sampling_loc, data_sampling_loc,
data_attn_weight, data_attn_weight,
batch_size, batch_size,
spatial_size, spatial_size,
num_heads, num_heads,
channels, channels,
num_levels, num_levels,
num_query, num_query,
num_point, num_point,
...@@ -1298,17 +1297,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream, ...@@ -1298,17 +1297,17 @@ void ms_deformable_col2im_cuda(cudaStream_t stream,
ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t> ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads*3*sizeof(scalar_t), stream>>>( num_threads*3*sizeof(scalar_t), stream>>>(
num_kernels, num_kernels,
grad_col, grad_col,
data_value, data_value,
data_spatial_shapes, data_spatial_shapes,
data_level_start_index, data_level_start_index,
data_sampling_loc, data_sampling_loc,
data_attn_weight, data_attn_weight,
batch_size, batch_size,
spatial_size, spatial_size,
num_heads, num_heads,
channels, channels,
num_levels, num_levels,
num_query, num_query,
num_point, num_point,
...@@ -1324,4 +1323,4 @@ void ms_deformable_col2im_cuda(cudaStream_t stream, ...@@ -1324,4 +1323,4 @@ void ms_deformable_col2im_cuda(cudaStream_t stream,
printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
} }
} }
\ No newline at end of file
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
#include "cpu/ms_deform_attn_cpu.h" #include "cpu/ms_deform_attn_cpu.h"
#ifdef WITH_CUDA #ifdef WITH_CUDA
#include "cuda/ms_deform_attn_cuda.h" #include "gpu/ms_deform_attn_cuda.h"
#endif #endif
...@@ -59,4 +59,3 @@ ms_deform_attn_backward( ...@@ -59,4 +59,3 @@ ms_deform_attn_backward(
} }
AT_ERROR("Not implemented on the CPU"); AT_ERROR("Not implemented on the CPU");
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment