Unverified Commit 6ce278bb authored by vfdev, committed by GitHub

Fixes deform_conv issue with large input/output (#4351)



* WIP on fixing index overflow issue

* Fixed backward pass for large num_kernels

* Fixed clang formatting

* Fixed GET_BLOCKS int/int64_t types issue
Co-authored-by: vfdev-5 <vfdev-5@gmail.com>
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent d9e6d60f
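
The underlying failure: these launchers computed element counts in 32-bit int, so for large inputs the product of channels, spatial extent, and batch size wrapped before the kernels ever ran (see https://github.com/pytorch/vision/issues/4269, referenced in the diff below). A standalone sketch of the overflow and of the promotion the patch applies (hypothetical sizes, not from the patch):

// Sketch of the overflow the patch guards against (hypothetical sizes).
// Signed int overflow is undefined behavior in C++; on typical two's-
// complement hardware the product simply wraps, as shown here.
#include <cstdint>
#include <iostream>

int main() {
  int n_in_channels = 256, out_h = 512, out_w = 512, parallel_imgs = 48;
  // 256 * 512 * 512 * 48 = 3221225472 > INT_MAX (2147483647)
  int overflowed = n_in_channels * out_h * out_w * parallel_imgs;
  // Promoting the first operand makes the whole chain evaluate in 64 bits,
  // mirroring the (int64_t) casts the patch adds to num_kernels below.
  int64_t exact = (int64_t)n_in_channels * out_h * out_w * parallel_imgs;
  std::cout << overflowed << " vs " << exact << '\n';  // -1073741824 vs 3221225472
  return 0;
}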
@@ -3,10 +3,12 @@
 namespace vision {
 namespace ops {
 
-#define CUDA_1D_KERNEL_LOOP(i, n)                                \
-  for (int i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \
+#define CUDA_1D_KERNEL_LOOP_T(i, n, index_t)                          \
+  for (index_t i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \
        i += (blockDim.x * gridDim.x))
 
+#define CUDA_1D_KERNEL_LOOP(i, n) CUDA_1D_KERNEL_LOOP_T(i, n, int)
+
 template <typename integer>
 constexpr __host__ __device__ inline integer ceil_div(integer n, integer m) {
   return (n + m - 1) / m;
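
With the templated loop macro in place, a kernel picks its index width at instantiation time while existing call sites keep the old 32-bit behavior. A sketch of the intended usage (hypothetical fill_kernel, assuming the macro above; not part of the patch):

template <typename scalar_t, typename index_t>
__global__ void fill_kernel(index_t n, scalar_t value, scalar_t* out) {
  // The loop variable i has type index_t, so for index_t = int64_t it cannot
  // wrap even when n exceeds 2^31 - 1.
  CUDA_1D_KERNEL_LOOP_T(i, n, index_t) {
    out[i] = value;
  }
}

// fill_kernel<float, int><<<blocks, threads>>>(n, 0.f, out);      // old behavior
// fill_kernel<float, int64_t><<<blocks, threads>>>(n, 0.f, out);  // large n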
@@ -88,29 +88,26 @@ inline unsigned int GET_THREADS() {
   return 512;
 }
 
-inline unsigned int GET_BLOCKS(
-    const unsigned int THREADS,
-    const unsigned int N) {
-  unsigned int kMaxGridNum =
-      at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
-  return std::min(kMaxGridNum, (N + THREADS - 1) / THREADS);
+inline unsigned int GET_BLOCKS(const unsigned int THREADS, const int64_t N) {
+  int64_t kMaxGridNum = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
+  return (unsigned int)std::min(kMaxGridNum, (N + THREADS - 1) / THREADS);
 }
 
-template <typename scalar_t>
+template <typename scalar_t, typename index_t>
 __device__ scalar_t bilinear_interpolate(
     const scalar_t* in,
-    int height,
-    int width,
+    index_t height,
+    index_t width,
     scalar_t h,
     scalar_t w) {
   if (h <= -1 || height <= h || w <= -1 || width <= w) {
     return 0;
   }
 
-  int h_low = floor(h);
-  int w_low = floor(w);
-  int h_high = h_low + 1;
-  int w_high = w_low + 1;
+  index_t h_low = floor(h);
+  index_t w_low = floor(w);
+  index_t h_high = h_low + 1;
+  index_t w_high = w_low + 1;
 
   scalar_t lh = h - h_low;
   scalar_t lw = w - w_low;
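
The reworked GET_BLOCKS does two things: the block count is computed in 64-bit arithmetic so the division itself cannot overflow, and the result is clamped to the device limit maxGridSize[0]; the grid-stride loops then cover whatever a clamped grid cannot reach in one pass. A host-side sketch (hypothetical element count, reusing the fill_kernel sketch above; not part of the patch):

void launch_example(float* out) {
  const int64_t n = int64_t(1) << 33;  // ~8.6e9 elements, far beyond INT_MAX
  const unsigned int threads = GET_THREADS();
  // With the old unsigned-int overload, (n + threads - 1) / threads would
  // already have truncated n; here the division is exact and the block count
  // is clamped to maxGridSize[0].
  const unsigned int blocks = GET_BLOCKS(threads, n);
  fill_kernel<float, int64_t><<<blocks, threads>>>(n, 0.f, out);
}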
@@ -135,38 +132,38 @@ __device__ scalar_t bilinear_interpolate(
   return val;
 }
 
-template <typename scalar_t>
+template <typename scalar_t, typename index_t>
 __global__ void deformable_im2col_kernel(
-    int n,
+    index_t n,
     const scalar_t* input_ptr,
     const scalar_t* offset_ptr,
     const scalar_t* mask_ptr,
-    int height,
-    int width,
-    int weight_h,
-    int weight_w,
-    int pad_h,
-    int pad_w,
-    int stride_h,
-    int stride_w,
-    int dilation_h,
-    int dilation_w,
-    int batch_sz,
-    int n_in_channels,
-    int n_offset_grps,
-    int out_h,
-    int out_w,
+    index_t height,
+    index_t width,
+    index_t weight_h,
+    index_t weight_w,
+    index_t pad_h,
+    index_t pad_w,
+    index_t stride_h,
+    index_t stride_w,
+    index_t dilation_h,
+    index_t dilation_w,
+    index_t batch_sz,
+    index_t n_in_channels,
+    index_t n_offset_grps,
+    index_t out_h,
+    index_t out_w,
     bool use_mask,
     scalar_t* columns_ptr) {
-  CUDA_1D_KERNEL_LOOP(index, n) {
-    const int out_x = index % out_w;
-    const int out_y = (index / out_w) % out_h;
-    const int out_b = (index / (out_w * out_h)) % batch_sz;
-    const int in_c = index / (out_w * out_h * batch_sz);
-    const int out_c = in_c * weight_h * weight_w;
+  CUDA_1D_KERNEL_LOOP_T(index, n, index_t) {
+    const index_t out_x = index % out_w;
+    const index_t out_y = (index / out_w) % out_h;
+    const index_t out_b = (index / (out_w * out_h)) % batch_sz;
+    const index_t in_c = index / (out_w * out_h * batch_sz);
+    const index_t out_c = in_c * weight_h * weight_w;
 
-    int c_per_offset_grp = n_in_channels / n_offset_grps;
-    const int grp_idx = in_c / c_per_offset_grp;
+    index_t c_per_offset_grp = n_in_channels / n_offset_grps;
+    const index_t grp_idx = in_c / c_per_offset_grp;
 
     columns_ptr +=
         (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) +
@@ -185,8 +182,8 @@ __global__ void deformable_im2col_kernel(
 
     for (int i = 0; i < weight_h; ++i) {
       for (int j = 0; j < weight_w; ++j) {
-        const int mask_idx = i * weight_w + j;
-        const int offset_idx = 2 * mask_idx;
+        const index_t mask_idx = i * weight_w + j;
+        const index_t offset_idx = 2 * mask_idx;
 
         scalar_t mask_value = 1;
         if (use_mask) {
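
As a reading aid for the indexing above (not part of the patch): the offset tensor is laid out as [batch, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w], so each kernel tap (i, j) owns a pair of channels, y-displacement first. A hypothetical helper making that explicit:

template <typename index_t>
__host__ __device__ index_t offset_channel(
    index_t i, index_t j, index_t weight_w, bool y_direction) {
  // mask_idx enumerates kernel taps row-major, exactly as in
  // deformable_im2col_kernel above; channel 2 * mask_idx holds the y-offset
  // and channel 2 * mask_idx + 1 the x-offset.
  const index_t mask_idx = i * weight_w + j;
  return 2 * mask_idx + (y_direction ? 0 : 1);
}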
@@ -231,36 +228,75 @@ void deformable_im2col(
     int deformable_group,
     bool use_mask,
     at::Tensor data_col) {
-  int num_kernels = n_in_channels * out_h * out_w * parallel_imgs;
+  int64_t num_kernels = (int64_t)n_in_channels * out_h * out_w * parallel_imgs;
 
   const unsigned int threads = GET_THREADS();
   const unsigned int blocks = GET_BLOCKS(threads, num_kernels);
 
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      input.scalar_type(), "deformable_im2col", ([&] {
-        deformable_im2col_kernel<<<blocks, threads>>>(
-            num_kernels,
-            input.data_ptr<scalar_t>(),
-            data_offset.data_ptr<scalar_t>(),
-            data_mask.data_ptr<scalar_t>(),
-            height,
-            width,
-            weight_h,
-            weight_w,
-            pad_h,
-            pad_w,
-            stride_h,
-            stride_w,
-            dilation_h,
-            dilation_w,
-            parallel_imgs,
-            n_in_channels,
-            deformable_group,
-            out_h,
-            out_w,
-            use_mask,
-            data_col.data_ptr<scalar_t>());
-      }));
+  // Checks if we should use 64bits indexing
+  // https://github.com/pytorch/vision/issues/4269
+  bool use_64bits_indexing = false;
+  // Checks if num_kernels or columns numel larger than 2 ** 31
+  use_64bits_indexing |= num_kernels > (1 << 31);
+  use_64bits_indexing |=
+      ((int64_t)n_in_channels * weight_h * weight_w * parallel_imgs * out_h *
+           out_w >
+       (1 << 31));
+
+  if (use_64bits_indexing) {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        input.scalar_type(), "deformable_im2col", ([&] {
+          deformable_im2col_kernel<scalar_t, int64_t><<<blocks, threads>>>(
+              num_kernels,
+              input.data_ptr<scalar_t>(),
+              data_offset.data_ptr<scalar_t>(),
+              data_mask.data_ptr<scalar_t>(),
+              height,
+              width,
+              weight_h,
+              weight_w,
+              pad_h,
+              pad_w,
+              stride_h,
+              stride_w,
+              dilation_h,
+              dilation_w,
+              parallel_imgs,
+              n_in_channels,
+              deformable_group,
+              out_h,
+              out_w,
+              use_mask,
+              data_col.data_ptr<scalar_t>());
+        }));
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        input.scalar_type(), "deformable_im2col", ([&] {
+          deformable_im2col_kernel<scalar_t, int><<<blocks, threads>>>(
+              num_kernels,
+              input.data_ptr<scalar_t>(),
+              data_offset.data_ptr<scalar_t>(),
+              data_mask.data_ptr<scalar_t>(),
+              height,
+              width,
+              weight_h,
+              weight_w,
+              pad_h,
+              pad_w,
+              stride_h,
+              stride_w,
+              dilation_h,
+              dilation_w,
+              parallel_imgs,
+              n_in_channels,
+              deformable_group,
+              out_h,
+              out_w,
+              use_mask,
+              data_col.data_ptr<scalar_t>());
+        }));
+  }
 
   cudaError_t err = cudaGetLastError();
   if (err != cudaSuccess) {
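
This dispatch pattern recurs in compute_grad_input and compute_grad_offset_and_mask below: 32-bit indexing stays on the fast path, and the kernels are re-instantiated with int64_t only when some count could not fit. Condensed into a single hedged helper (hypothetical, not in the patch), the rule is:

#include <cstdint>

// Use 64-bit indexing whenever the launch count or the columns buffer could
// produce a linear index beyond what a 32-bit int can represent.
bool needs_64bit_indexing(int64_t num_kernels, int64_t columns_numel) {
  const int64_t threshold = int64_t(1) << 31;
  return num_kernels > threshold || columns_numel > threshold;
}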
@@ -277,39 +313,40 @@ int get_greatest_divisor_below_bound(int n, int bound) {
   return 1;
 }
 
-template <typename scalar_t>
+template <typename scalar_t, typename index_t>
 __global__ void deformable_col2im_kernel(
-    int n,
+    index_t n,
     const scalar_t* col,
     const scalar_t* offset_ptr,
     const scalar_t* mask_ptr,
-    int channels,
-    int height,
-    int width,
-    int kernel_h,
-    int kernel_w,
-    int pad_h,
-    int pad_w,
-    int stride_h,
-    int stride_w,
-    int dilation_h,
-    int dilation_w,
-    int batch_sz,
-    int n_offset_grps,
-    int out_h,
-    int out_w,
+    index_t channels,
+    index_t height,
+    index_t width,
+    index_t kernel_h,
+    index_t kernel_w,
+    index_t pad_h,
+    index_t pad_w,
+    index_t stride_h,
+    index_t stride_w,
+    index_t dilation_h,
+    index_t dilation_w,
+    index_t batch_sz,
+    index_t n_offset_grps,
+    index_t out_h,
+    index_t out_w,
     bool use_mask,
     scalar_t* grad_im) {
-  CUDA_1D_KERNEL_LOOP(index, n) {
-    const int out_x = index % out_w;
-    const int out_y = (index / out_w) % out_h;
-    const int b = (index / (out_w * out_h)) % batch_sz;
-    const int j = (index / (out_w * out_h * batch_sz)) % kernel_w;
-    const int i = (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h;
-    const int c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h);
+  CUDA_1D_KERNEL_LOOP_T(index, n, int64_t) {
+    const index_t out_x = index % out_w;
+    const index_t out_y = (index / out_w) % out_h;
+    const index_t b = (index / (out_w * out_h)) % batch_sz;
+    const index_t j = (index / (out_w * out_h * batch_sz)) % kernel_w;
+    const index_t i =
+        (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h;
+    const index_t c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h);
 
-    int c_per_offset_grp = channels / n_offset_grps;
-    const int offset_grp = c / c_per_offset_grp;
+    index_t c_per_offset_grp = channels / n_offset_grps;
+    const index_t offset_grp = c / c_per_offset_grp;
 
     offset_ptr += (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w *
         out_h * out_w;
@@ -319,11 +356,12 @@ __global__ void deformable_col2im_kernel(
           out_h * out_w;
     }
 
-    const int mask_idx = i * kernel_w + j;
-    const int offset_idx = 2 * mask_idx;
+    const index_t mask_idx = i * kernel_w + j;
+    const index_t offset_idx = 2 * mask_idx;
 
-    const int offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x;
-    const int offset_w_ptr = ((offset_idx + 1) * out_h + out_y) * out_w + out_x;
+    const index_t offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x;
+    const index_t offset_w_ptr =
+        ((offset_idx + 1) * out_h + out_y) * out_w + out_x;
 
     const scalar_t offset_h = offset_ptr[offset_h_ptr];
     const scalar_t offset_w = offset_ptr[offset_w_ptr];
@@ -336,13 +374,13 @@ __global__ void deformable_col2im_kernel(
     const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;
     const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;
 
-    for (int dy = -1; dy <= 1; dy++) {
-      for (int dx = -1; dx <= 1; dx++) {
-        int yp = int(y) + dy;
-        int xp = int(x) + dx;
+    for (index_t dy = -1; dy <= 1; dy++) {
+      for (index_t dx = -1; dx <= 1; dx++) {
+        index_t yp = (index_t)y + dy;
+        index_t xp = (index_t)x + dx;
         if (0 <= yp && yp < height && 0 <= xp && xp < width &&
             std::abs(y - yp) < 1 && std::abs(x - xp) < 1) {
-          int grad_pos = ((b * channels + c) * height + yp) * width + xp;
+          index_t grad_pos = ((b * channels + c) * height + yp) * width + xp;
           scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp));
           atomicAdd(grad_im + grad_pos, mask_value * weight * col[index]);
         }
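
The loop above scatters each column gradient to the integer neighbors of the sampling point, weighted bilinearly; atomicAdd is required because different output positions can land on the same input pixel. A host-side demonstration (hypothetical coordinates, not from the patch) that the neighbor weights partition the gradient:

// Each integer neighbor (yp, xp) of the sampling point (y, x) with
// |y - yp| < 1 and |x - xp| < 1 receives the share
// (1 - |y - yp|) * (1 - |x - xp|); the shares sum to 1.
#include <cmath>
#include <cstdio>

int main() {
  const float y = 2.3f, x = 5.75f;
  float total = 0.f;
  for (int dy = -1; dy <= 1; ++dy) {
    for (int dx = -1; dx <= 1; ++dx) {
      const int yp = int(y) + dy, xp = int(x) + dx;
      if (std::fabs(y - yp) < 1 && std::fabs(x - xp) < 1) {
        const float w = (1 - std::fabs(y - yp)) * (1 - std::fabs(x - xp));
        std::printf("neighbor (%d, %d): weight %.4f\n", yp, xp, w);
        total += w;
      }
    }
  }
  std::printf("total = %.4f\n", total);  // 1.0000
  return 0;
}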
@@ -374,37 +412,72 @@ void compute_grad_input(
       (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
   int out_w =
       (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1;
-  int num_kernels =
-      channels * weight_h * weight_w * out_h * out_w * parallel_imgs;
+  int64_t num_kernels =
+      (int64_t)channels * weight_h * weight_w * out_h * out_w * parallel_imgs;
 
   const unsigned int threads = GET_THREADS();
   const unsigned int blocks = GET_BLOCKS(threads, num_kernels);
 
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      columns.scalar_type(), "compute_grad_input", ([&] {
-        deformable_col2im_kernel<<<blocks, threads>>>(
-            num_kernels,
-            columns.data_ptr<scalar_t>(),
-            offset.data_ptr<scalar_t>(),
-            mask.data_ptr<scalar_t>(),
-            channels,
-            height,
-            width,
-            weight_h,
-            weight_w,
-            pad_h,
-            pad_w,
-            stride_h,
-            stride_w,
-            dilation_h,
-            dilation_w,
-            parallel_imgs,
-            n_offset_grps,
-            out_h,
-            out_w,
-            use_mask,
-            grad_im.data_ptr<scalar_t>());
-      }));
+  // Checks if we should use 64bits indexing
+  // https://github.com/pytorch/vision/issues/4269
+  bool use_64bits_indexing = false;
+  // Checks if num_kernels or columns numel larger than 2 ** 31
+  use_64bits_indexing |= num_kernels > (1 << 31);
+
+  if (use_64bits_indexing) {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        columns.scalar_type(), "compute_grad_input", ([&] {
+          deformable_col2im_kernel<scalar_t, int64_t><<<blocks, threads>>>(
+              num_kernels,
+              columns.data_ptr<scalar_t>(),
+              offset.data_ptr<scalar_t>(),
+              mask.data_ptr<scalar_t>(),
+              channels,
+              height,
+              width,
+              weight_h,
+              weight_w,
+              pad_h,
+              pad_w,
+              stride_h,
+              stride_w,
+              dilation_h,
+              dilation_w,
+              parallel_imgs,
+              n_offset_grps,
+              out_h,
+              out_w,
+              use_mask,
+              grad_im.data_ptr<scalar_t>());
+        }));
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        columns.scalar_type(), "compute_grad_input", ([&] {
+          deformable_col2im_kernel<scalar_t, int><<<blocks, threads>>>(
+              num_kernels,
+              columns.data_ptr<scalar_t>(),
+              offset.data_ptr<scalar_t>(),
+              mask.data_ptr<scalar_t>(),
+              channels,
+              height,
+              width,
+              weight_h,
+              weight_w,
+              pad_h,
+              pad_w,
+              stride_h,
+              stride_w,
+              dilation_h,
+              dilation_w,
+              parallel_imgs,
+              n_offset_grps,
+              out_h,
+              out_w,
+              use_mask,
+              grad_im.data_ptr<scalar_t>());
+        }));
+  }
 
   cudaError_t err = cudaGetLastError();
   if (err != cudaSuccess) {
@@ -412,18 +485,18 @@ void compute_grad_input(
   }
 }
 
-template <typename scalar_t>
+template <typename scalar_t, typename index_t>
 __device__ scalar_t get_coordinate_weight(
     const scalar_t* im_data,
-    int height,
-    int width,
+    index_t height,
+    index_t width,
     scalar_t y,
     scalar_t x,
     bool is_y_direction) {
-  int y_l = floor(y);
-  int x_l = floor(x);
-  int y_h = y_l + 1;
-  int x_h = x_l + 1;
+  index_t y_l = floor(y);
+  index_t x_l = floor(x);
+  index_t y_h = y_l + 1;
+  index_t x_h = x_l + 1;
 
   bool valid_y_l = 0 <= y_l && y_l < height;
   bool valid_y_h = 0 <= y_h && y_h < height;
@@ -445,47 +518,47 @@ __device__ scalar_t get_coordinate_weight(
   }
 }
 
-template <typename scalar_t>
+template <typename scalar_t, typename index_t>
 __global__ void deformable_col2im_coord_kernel(
-    int n,
+    index_t n,
     const scalar_t* col_ptr,
     const scalar_t* im_ptr,
     const scalar_t* offset_ptr,
     const scalar_t* mask_ptr,
-    int channels,
-    int height,
-    int width,
-    int weight_h,
-    int weight_w,
-    int pad_h,
-    int pad_w,
-    int stride_h,
-    int stride_w,
-    int dilation_h,
-    int dilation_w,
-    int batch_sz,
-    int offset_channels,
-    int n_offset_grps,
-    int out_h,
-    int out_w,
+    index_t channels,
+    index_t height,
+    index_t width,
+    index_t weight_h,
+    index_t weight_w,
+    index_t pad_h,
+    index_t pad_w,
+    index_t stride_h,
+    index_t stride_w,
+    index_t dilation_h,
+    index_t dilation_w,
+    index_t batch_sz,
+    index_t offset_channels,
+    index_t n_offset_grps,
+    index_t out_h,
+    index_t out_w,
     const bool use_mask,
     scalar_t* grad_offset,
     scalar_t* grad_mask) {
-  CUDA_1D_KERNEL_LOOP(index, n) {
+  CUDA_1D_KERNEL_LOOP_T(index, n, int64_t) {
     scalar_t grad_offset_val = 0;
     scalar_t grad_mask_val = 0;
 
-    int w = index % out_w;
-    int h = (index / out_w) % out_h;
-    int w_w = (index / (out_w * out_h * 2)) % weight_w;
-    int w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h;
-    int c = (index / (out_w * out_h)) % offset_channels;
-    int b = index / (out_w * out_h * offset_channels);
+    index_t w = index % out_w;
+    index_t h = (index / out_w) % out_h;
+    index_t w_w = (index / (out_w * out_h * 2)) % weight_w;
+    index_t w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h;
+    index_t c = (index / (out_w * out_h)) % offset_channels;
+    index_t b = index / (out_w * out_h * offset_channels);
 
-    const int offset_grp = c / (2 * weight_h * weight_w);
-    const int col_step = weight_h * weight_w;
+    const index_t offset_grp = c / (2 * weight_h * weight_w);
+    const index_t col_step = weight_h * weight_w;
 
-    int c_per_offset_grp = channels / n_offset_grps;
+    index_t c_per_offset_grp = channels / n_offset_grps;
 
     col_ptr += offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz *
         out_w * out_h;
@@ -499,23 +572,24 @@ __global__ void deformable_col2im_coord_kernel(
           out_h * out_w;
     }
 
-    const int offset_c = c - offset_grp * 2 * weight_h * weight_w;
+    const index_t offset_c = c - offset_grp * 2 * weight_h * weight_w;
     const bool is_y_direction = offset_c % 2 == 0;
 
-    const int c_bound = c_per_offset_grp * weight_h * weight_w;
-    for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) {
-      const int col_pos = (((col_c * batch_sz + b) * out_h) + h) * out_w + w;
+    const index_t c_bound = c_per_offset_grp * weight_h * weight_w;
+    for (index_t col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) {
+      const index_t col_pos =
+          (((col_c * batch_sz + b) * out_h) + h) * out_w + w;
 
-      int out_x = col_pos % out_w;
-      int out_y = (col_pos / out_w) % out_h;
-      int j = (col_pos / (out_w * out_h * batch_sz)) % weight_w;
-      int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h;
+      index_t out_x = col_pos % out_w;
+      index_t out_y = (col_pos / out_w) % out_h;
+      index_t j = (col_pos / (out_w * out_h * batch_sz)) % weight_w;
+      index_t i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h;
 
-      const int mask_idx = i * weight_w + j;
+      const index_t mask_idx = i * weight_w + j;
 
-      const int offset_h_ptr =
+      const index_t offset_h_ptr =
           (((2 * mask_idx) * out_h + out_y) * out_w + out_x);
-      const int offset_w_ptr =
+      const index_t offset_w_ptr =
           (((2 * mask_idx + 1) * out_h + out_y) * out_w + out_x);
 
       const scalar_t offset_h = offset_ptr[offset_h_ptr];
       const scalar_t offset_w = offset_ptr[offset_w_ptr];
@@ -543,7 +617,7 @@ __global__ void deformable_col2im_coord_kernel(
     grad_offset[index] = grad_offset_val;
 
     if (use_mask && is_y_direction) {
-      const int idx =
+      const index_t idx =
           ((((b * n_offset_grps + offset_grp) * weight_h + w_h) * weight_w +
             w_w) *
                out_h +
@@ -580,40 +654,81 @@ void compute_grad_offset_and_mask(
       (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
   int out_w =
       (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1;
-  int num_kernels =
-      out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs;
+  int64_t num_kernels = (int64_t)out_h * out_w * 2 * weight_h * weight_w *
+      n_offset_grps * parallel_imgs;
 
   const unsigned int threads = GET_THREADS();
   const unsigned int blocks = GET_BLOCKS(threads, num_kernels);
 
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      columns.scalar_type(), "compute_grad_offset_and_mask", ([&] {
-        deformable_col2im_coord_kernel<<<blocks, threads>>>(
-            num_kernels,
-            columns.data_ptr<scalar_t>(),
-            input.data_ptr<scalar_t>(),
-            offset.data_ptr<scalar_t>(),
-            mask.data_ptr<scalar_t>(),
-            channels,
-            height,
-            width,
-            weight_h,
-            weight_w,
-            pad_h,
-            pad_w,
-            stride_h,
-            stride_w,
-            dilation_h,
-            dilation_w,
-            parallel_imgs,
-            2 * weight_h * weight_w * n_offset_grps,
-            n_offset_grps,
-            out_h,
-            out_w,
-            use_mask,
-            grad_offset.data_ptr<scalar_t>(),
-            grad_mask.data_ptr<scalar_t>());
-      }));
+  // Checks if we should use 64bits indexing
+  // https://github.com/pytorch/vision/issues/4269
+  bool use_64bits_indexing = false;
+  // Checks if columns numel is larger than 2 ** 31
+  use_64bits_indexing |= num_kernels > (1 << 31);
+  use_64bits_indexing |=
+      ((int64_t)channels * weight_h * weight_w * parallel_imgs * out_h * out_w >
+       (1 << 31));
+
+  if (use_64bits_indexing) {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        columns.scalar_type(), "compute_grad_offset_and_mask", ([&] {
+          deformable_col2im_coord_kernel<scalar_t, int64_t>
+              <<<blocks, threads>>>(
+                  num_kernels,
+                  columns.data_ptr<scalar_t>(),
+                  input.data_ptr<scalar_t>(),
+                  offset.data_ptr<scalar_t>(),
+                  mask.data_ptr<scalar_t>(),
+                  channels,
+                  height,
+                  width,
+                  weight_h,
+                  weight_w,
+                  pad_h,
+                  pad_w,
+                  stride_h,
+                  stride_w,
+                  dilation_h,
+                  dilation_w,
+                  parallel_imgs,
+                  2 * weight_h * weight_w * n_offset_grps,
+                  n_offset_grps,
+                  out_h,
+                  out_w,
+                  use_mask,
+                  grad_offset.data_ptr<scalar_t>(),
+                  grad_mask.data_ptr<scalar_t>());
+        }));
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        columns.scalar_type(), "compute_grad_offset_and_mask", ([&] {
+          deformable_col2im_coord_kernel<scalar_t, int><<<blocks, threads>>>(
+              num_kernels,
+              columns.data_ptr<scalar_t>(),
+              input.data_ptr<scalar_t>(),
+              offset.data_ptr<scalar_t>(),
+              mask.data_ptr<scalar_t>(),
+              channels,
+              height,
+              width,
+              weight_h,
+              weight_w,
+              pad_h,
+              pad_w,
+              stride_h,
+              stride_w,
+              dilation_h,
+              dilation_w,
+              parallel_imgs,
+              2 * weight_h * weight_w * n_offset_grps,
+              n_offset_grps,
+              out_h,
+              out_w,
+              use_mask,
+              grad_offset.data_ptr<scalar_t>(),
+              grad_mask.data_ptr<scalar_t>());
+        }));
+  }
 
   cudaError_t err = cudaGetLastError();
   if (err != cudaSuccess) {