Unverified Commit 6ce278bb authored by vfdev, committed by GitHub

Fixes deform_conv issue with large input/output (#4351)



* WIP on fixing index overflow issue

* Fixed backward pass for large num_kernels

* Fixed clang formatting

* Fixed GET_BLOCKS int/int64_t types issue
Co-authored-by: vfdev-5 <vfdev-5@gmail.com>
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent d9e6d60f
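
Note on the change, for readers of this diff: the per-launch element count (e.g. n_in_channels * out_h * out_w * parallel_imgs) was computed in 32-bit int and wraps once it exceeds 2^31 - 1, and the grid-stride loop counter inside the kernels wrapped the same way. The commit computes the count as int64_t, derives the launch configuration from the 64-bit value, and dispatches each kernel on an index type (int for small problems, int64_t for large ones). The following is a minimal standalone sketch of that pattern, not the torchvision code; KERNEL_LOOP_T, fill_ones_kernel, kMaxGridX, and launch_fill_ones are illustrative names only.

#include <algorithm>
#include <cstdint>
#include <limits>

// Grid-stride loop whose counter type is a parameter, mirroring the
// CUDA_1D_KERNEL_LOOP_T macro introduced in this diff.
#define KERNEL_LOOP_T(i, n, index_t)                                \
  for (index_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);  \
       i += (index_t)blockDim.x * gridDim.x)

template <typename scalar_t, typename index_t>
__global__ void fill_ones_kernel(index_t n, scalar_t* out) {
  // With index_t = int, both the loop counter and the subscript wrap once
  // n exceeds 2^31 - 1, which is the failure mode seen with large inputs.
  KERNEL_LOOP_T(i, n, index_t) {
    out[i] = scalar_t(1);
  }
}

// Cap on gridDim.x; the real code queries maxGridSize[0] from the device.
constexpr int64_t kMaxGridX = 65535;

void launch_fill_ones(float* out, int64_t n) {
  const unsigned int threads = 512;
  const unsigned int blocks = (unsigned int)std::min(
      kMaxGridX, (n + threads - 1) / (int64_t)threads);
  // Dispatch on the element count: use 64-bit indices only when actually
  // needed, since 32-bit index arithmetic is cheaper on the GPU.
  if (n > (int64_t)std::numeric_limits<int>::max()) {
    fill_ones_kernel<float, int64_t><<<blocks, threads>>>(n, out);
  } else {
    fill_ones_kernel<float, int><<<blocks, threads>>>((int)n, out);
  }
}
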
@@ -3,10 +3,12 @@
 namespace vision {
 namespace ops {

-#define CUDA_1D_KERNEL_LOOP(i, n)                                \
-  for (int i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \
+#define CUDA_1D_KERNEL_LOOP_T(i, n, index_t)                         \
+  for (index_t i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \
        i += (blockDim.x * gridDim.x))

+#define CUDA_1D_KERNEL_LOOP(i, n) CUDA_1D_KERNEL_LOOP_T(i, n, int)
+
 template <typename integer>
 constexpr __host__ __device__ inline integer ceil_div(integer n, integer m) {
   return (n + m - 1) / m;
@@ -88,29 +88,26 @@ inline unsigned int GET_THREADS() {
   return 512;
 }

-inline unsigned int GET_BLOCKS(
-    const unsigned int THREADS,
-    const unsigned int N) {
-  unsigned int kMaxGridNum =
-      at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
-  return std::min(kMaxGridNum, (N + THREADS - 1) / THREADS);
+inline unsigned int GET_BLOCKS(const unsigned int THREADS, const int64_t N) {
+  int64_t kMaxGridNum = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
+  return (unsigned int)std::min(kMaxGridNum, (N + THREADS - 1) / THREADS);
 }

-template <typename scalar_t>
+template <typename scalar_t, typename index_t>
 __device__ scalar_t bilinear_interpolate(
     const scalar_t* in,
-    int height,
-    int width,
+    index_t height,
+    index_t width,
     scalar_t h,
     scalar_t w) {
   if (h <= -1 || height <= h || w <= -1 || width <= w) {
     return 0;
   }

-  int h_low = floor(h);
-  int w_low = floor(w);
-  int h_high = h_low + 1;
-  int w_high = w_low + 1;
+  index_t h_low = floor(h);
+  index_t w_low = floor(w);
+  index_t h_high = h_low + 1;
+  index_t w_high = w_low + 1;

   scalar_t lh = h - h_low;
   scalar_t lw = w - w_low;
@@ -135,38 +132,38 @@ __device__ scalar_t bilinear_interpolate(
   return val;
 }

-template <typename scalar_t>
+template <typename scalar_t, typename index_t>
 __global__ void deformable_im2col_kernel(
-    int n,
+    index_t n,
     const scalar_t* input_ptr,
     const scalar_t* offset_ptr,
     const scalar_t* mask_ptr,
-    int height,
-    int width,
-    int weight_h,
-    int weight_w,
-    int pad_h,
-    int pad_w,
-    int stride_h,
-    int stride_w,
-    int dilation_h,
-    int dilation_w,
-    int batch_sz,
-    int n_in_channels,
-    int n_offset_grps,
-    int out_h,
-    int out_w,
+    index_t height,
+    index_t width,
+    index_t weight_h,
+    index_t weight_w,
+    index_t pad_h,
+    index_t pad_w,
+    index_t stride_h,
+    index_t stride_w,
+    index_t dilation_h,
+    index_t dilation_w,
+    index_t batch_sz,
+    index_t n_in_channels,
+    index_t n_offset_grps,
+    index_t out_h,
+    index_t out_w,
     bool use_mask,
     scalar_t* columns_ptr) {
-  CUDA_1D_KERNEL_LOOP(index, n) {
-    const int out_x = index % out_w;
-    const int out_y = (index / out_w) % out_h;
-    const int out_b = (index / (out_w * out_h)) % batch_sz;
-    const int in_c = index / (out_w * out_h * batch_sz);
-    const int out_c = in_c * weight_h * weight_w;
+  CUDA_1D_KERNEL_LOOP_T(index, n, index_t) {
+    const index_t out_x = index % out_w;
+    const index_t out_y = (index / out_w) % out_h;
+    const index_t out_b = (index / (out_w * out_h)) % batch_sz;
+    const index_t in_c = index / (out_w * out_h * batch_sz);
+    const index_t out_c = in_c * weight_h * weight_w;

-    int c_per_offset_grp = n_in_channels / n_offset_grps;
-    const int grp_idx = in_c / c_per_offset_grp;
+    index_t c_per_offset_grp = n_in_channels / n_offset_grps;
+    const index_t grp_idx = in_c / c_per_offset_grp;

     columns_ptr +=
         (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) +
@@ -185,8 +182,8 @@ __global__ void deformable_im2col_kernel(

     for (int i = 0; i < weight_h; ++i) {
       for (int j = 0; j < weight_w; ++j) {
-        const int mask_idx = i * weight_w + j;
-        const int offset_idx = 2 * mask_idx;
+        const index_t mask_idx = i * weight_w + j;
+        const index_t offset_idx = 2 * mask_idx;

         scalar_t mask_value = 1;
         if (use_mask) {
@@ -231,14 +228,25 @@ void deformable_im2col(
     int deformable_group,
     bool use_mask,
     at::Tensor data_col) {
-  int num_kernels = n_in_channels * out_h * out_w * parallel_imgs;
+  int64_t num_kernels = (int64_t)n_in_channels * out_h * out_w * parallel_imgs;

   const unsigned int threads = GET_THREADS();
   const unsigned int blocks = GET_BLOCKS(threads, num_kernels);

+  // Checks if we should use 64bits indexing
+  // https://github.com/pytorch/vision/issues/4269
+  bool use_64bits_indexing = false;
+  // Checks if num_kernels or columns numel larger than 2 ** 31
+  use_64bits_indexing |= num_kernels > (1 << 31);
+  use_64bits_indexing |=
+      ((int64_t)n_in_channels * weight_h * weight_w * parallel_imgs * out_h *
+           out_w >
+       (1 << 31));
+
+  if (use_64bits_indexing) {
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       input.scalar_type(), "deformable_im2col", ([&] {
-        deformable_im2col_kernel<<<blocks, threads>>>(
+        deformable_im2col_kernel<scalar_t, int64_t><<<blocks, threads>>>(
             num_kernels,
             input.data_ptr<scalar_t>(),
             data_offset.data_ptr<scalar_t>(),
@@ -262,6 +270,34 @@
             data_col.data_ptr<scalar_t>());
       }));
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        input.scalar_type(), "deformable_im2col", ([&] {
+          deformable_im2col_kernel<scalar_t, int><<<blocks, threads>>>(
+              num_kernels,
+              input.data_ptr<scalar_t>(),
+              data_offset.data_ptr<scalar_t>(),
+              data_mask.data_ptr<scalar_t>(),
+              height,
+              width,
+              weight_h,
+              weight_w,
+              pad_h,
+              pad_w,
+              stride_h,
+              stride_w,
+              dilation_h,
+              dilation_w,
+              parallel_imgs,
+              n_in_channels,
+              deformable_group,
+              out_h,
+              out_w,
+              use_mask,
+              data_col.data_ptr<scalar_t>());
+        }));
+  }

   cudaError_t err = cudaGetLastError();
   if (err != cudaSuccess) {
     printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
@@ -277,39 +313,40 @@ int get_greatest_divisor_below_bound(int n, int bound) {
   return 1;
 }

-template <typename scalar_t>
+template <typename scalar_t, typename index_t>
 __global__ void deformable_col2im_kernel(
-    int n,
+    index_t n,
     const scalar_t* col,
     const scalar_t* offset_ptr,
     const scalar_t* mask_ptr,
-    int channels,
-    int height,
-    int width,
-    int kernel_h,
-    int kernel_w,
-    int pad_h,
-    int pad_w,
-    int stride_h,
-    int stride_w,
-    int dilation_h,
-    int dilation_w,
-    int batch_sz,
-    int n_offset_grps,
-    int out_h,
-    int out_w,
+    index_t channels,
+    index_t height,
+    index_t width,
+    index_t kernel_h,
+    index_t kernel_w,
+    index_t pad_h,
+    index_t pad_w,
+    index_t stride_h,
+    index_t stride_w,
+    index_t dilation_h,
+    index_t dilation_w,
+    index_t batch_sz,
+    index_t n_offset_grps,
+    index_t out_h,
+    index_t out_w,
     bool use_mask,
     scalar_t* grad_im) {
-  CUDA_1D_KERNEL_LOOP(index, n) {
-    const int out_x = index % out_w;
-    const int out_y = (index / out_w) % out_h;
-    const int b = (index / (out_w * out_h)) % batch_sz;
-    const int j = (index / (out_w * out_h * batch_sz)) % kernel_w;
-    const int i = (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h;
-    const int c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h);
+  CUDA_1D_KERNEL_LOOP_T(index, n, int64_t) {
+    const index_t out_x = index % out_w;
+    const index_t out_y = (index / out_w) % out_h;
+    const index_t b = (index / (out_w * out_h)) % batch_sz;
+    const index_t j = (index / (out_w * out_h * batch_sz)) % kernel_w;
+    const index_t i =
+        (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h;
+    const index_t c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h);

-    int c_per_offset_grp = channels / n_offset_grps;
-    const int offset_grp = c / c_per_offset_grp;
+    index_t c_per_offset_grp = channels / n_offset_grps;
+    const index_t offset_grp = c / c_per_offset_grp;

     offset_ptr += (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w *
         out_h * out_w;
@@ -319,11 +356,12 @@ __global__ void deformable_col2im_kernel(
           out_h * out_w;
     }

-    const int mask_idx = i * kernel_w + j;
-    const int offset_idx = 2 * mask_idx;
+    const index_t mask_idx = i * kernel_w + j;
+    const index_t offset_idx = 2 * mask_idx;

-    const int offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x;
-    const int offset_w_ptr = ((offset_idx + 1) * out_h + out_y) * out_w + out_x;
+    const index_t offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x;
+    const index_t offset_w_ptr =
+        ((offset_idx + 1) * out_h + out_y) * out_w + out_x;

     const scalar_t offset_h = offset_ptr[offset_h_ptr];
     const scalar_t offset_w = offset_ptr[offset_w_ptr];
@@ -336,13 +374,13 @@ __global__ void deformable_col2im_kernel(
     const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;
     const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;

-    for (int dy = -1; dy <= 1; dy++) {
-      for (int dx = -1; dx <= 1; dx++) {
-        int yp = int(y) + dy;
-        int xp = int(x) + dx;
+    for (index_t dy = -1; dy <= 1; dy++) {
+      for (index_t dx = -1; dx <= 1; dx++) {
+        index_t yp = (index_t)y + dy;
+        index_t xp = (index_t)x + dx;
         if (0 <= yp && yp < height && 0 <= xp && xp < width &&
             std::abs(y - yp) < 1 && std::abs(x - xp) < 1) {
-          int grad_pos = ((b * channels + c) * height + yp) * width + xp;
+          index_t grad_pos = ((b * channels + c) * height + yp) * width + xp;
           scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp));
           atomicAdd(grad_im + grad_pos, mask_value * weight * col[index]);
         }
@@ -374,15 +412,49 @@ void compute_grad_input(
       (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
   int out_w =
       (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1;
-  int num_kernels =
-      channels * weight_h * weight_w * out_h * out_w * parallel_imgs;
+  int64_t num_kernels =
+      (int64_t)channels * weight_h * weight_w * out_h * out_w * parallel_imgs;

   const unsigned int threads = GET_THREADS();
   const unsigned int blocks = GET_BLOCKS(threads, num_kernels);

+  // Checks if we should use 64bits indexing
+  // https://github.com/pytorch/vision/issues/4269
+  bool use_64bits_indexing = false;
+  // Checks if num_kernels or columns numel larger than 2 ** 31
+  use_64bits_indexing |= num_kernels > (1 << 31);
+
+  if (use_64bits_indexing) {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        columns.scalar_type(), "compute_grad_input", ([&] {
+          deformable_col2im_kernel<scalar_t, int64_t><<<blocks, threads>>>(
+              num_kernels,
+              columns.data_ptr<scalar_t>(),
+              offset.data_ptr<scalar_t>(),
+              mask.data_ptr<scalar_t>(),
+              channels,
+              height,
+              width,
+              weight_h,
+              weight_w,
+              pad_h,
+              pad_w,
+              stride_h,
+              stride_w,
+              dilation_h,
+              dilation_w,
+              parallel_imgs,
+              n_offset_grps,
+              out_h,
+              out_w,
+              use_mask,
+              grad_im.data_ptr<scalar_t>());
+        }));
+  } else {
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       columns.scalar_type(), "compute_grad_input", ([&] {
-        deformable_col2im_kernel<<<blocks, threads>>>(
+        deformable_col2im_kernel<scalar_t, int><<<blocks, threads>>>(
             num_kernels,
             columns.data_ptr<scalar_t>(),
             offset.data_ptr<scalar_t>(),
@@ -405,6 +477,7 @@ void compute_grad_input(
             use_mask,
             grad_im.data_ptr<scalar_t>());
       }));
+  }

   cudaError_t err = cudaGetLastError();
   if (err != cudaSuccess) {
@@ -412,18 +485,18 @@ void compute_grad_input(
   }
 }

-template <typename scalar_t>
+template <typename scalar_t, typename index_t>
 __device__ scalar_t get_coordinate_weight(
     const scalar_t* im_data,
-    int height,
-    int width,
+    index_t height,
+    index_t width,
     scalar_t y,
     scalar_t x,
     bool is_y_direction) {
-  int y_l = floor(y);
-  int x_l = floor(x);
-  int y_h = y_l + 1;
-  int x_h = x_l + 1;
+  index_t y_l = floor(y);
+  index_t x_l = floor(x);
+  index_t y_h = y_l + 1;
+  index_t x_h = x_l + 1;

   bool valid_y_l = 0 <= y_l && y_l < height;
   bool valid_y_h = 0 <= y_h && y_h < height;
@@ -445,47 +518,47 @@ __device__ scalar_t get_coordinate_weight(
   }
 }

-template <typename scalar_t>
+template <typename scalar_t, typename index_t>
 __global__ void deformable_col2im_coord_kernel(
-    int n,
+    index_t n,
     const scalar_t* col_ptr,
     const scalar_t* im_ptr,
     const scalar_t* offset_ptr,
     const scalar_t* mask_ptr,
-    int channels,
-    int height,
-    int width,
-    int weight_h,
-    int weight_w,
-    int pad_h,
-    int pad_w,
-    int stride_h,
-    int stride_w,
-    int dilation_h,
-    int dilation_w,
-    int batch_sz,
-    int offset_channels,
-    int n_offset_grps,
-    int out_h,
-    int out_w,
+    index_t channels,
+    index_t height,
+    index_t width,
+    index_t weight_h,
+    index_t weight_w,
+    index_t pad_h,
+    index_t pad_w,
+    index_t stride_h,
+    index_t stride_w,
+    index_t dilation_h,
+    index_t dilation_w,
+    index_t batch_sz,
+    index_t offset_channels,
+    index_t n_offset_grps,
+    index_t out_h,
+    index_t out_w,
     const bool use_mask,
     scalar_t* grad_offset,
     scalar_t* grad_mask) {
-  CUDA_1D_KERNEL_LOOP(index, n) {
+  CUDA_1D_KERNEL_LOOP_T(index, n, int64_t) {
     scalar_t grad_offset_val = 0;
     scalar_t grad_mask_val = 0;
-    int w = index % out_w;
-    int h = (index / out_w) % out_h;
-    int w_w = (index / (out_w * out_h * 2)) % weight_w;
-    int w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h;
-    int c = (index / (out_w * out_h)) % offset_channels;
-    int b = index / (out_w * out_h * offset_channels);
+    index_t w = index % out_w;
+    index_t h = (index / out_w) % out_h;
+    index_t w_w = (index / (out_w * out_h * 2)) % weight_w;
+    index_t w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h;
+    index_t c = (index / (out_w * out_h)) % offset_channels;
+    index_t b = index / (out_w * out_h * offset_channels);

-    const int offset_grp = c / (2 * weight_h * weight_w);
-    const int col_step = weight_h * weight_w;
+    const index_t offset_grp = c / (2 * weight_h * weight_w);
+    const index_t col_step = weight_h * weight_w;

-    int c_per_offset_grp = channels / n_offset_grps;
+    index_t c_per_offset_grp = channels / n_offset_grps;

     col_ptr += offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz *
         out_w * out_h;
@@ -499,23 +572,24 @@ __global__ void deformable_col2im_coord_kernel(
           out_h * out_w;
     }

-    const int offset_c = c - offset_grp * 2 * weight_h * weight_w;
+    const index_t offset_c = c - offset_grp * 2 * weight_h * weight_w;
     const bool is_y_direction = offset_c % 2 == 0;

-    const int c_bound = c_per_offset_grp * weight_h * weight_w;
-    for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) {
-      const int col_pos = (((col_c * batch_sz + b) * out_h) + h) * out_w + w;
+    const index_t c_bound = c_per_offset_grp * weight_h * weight_w;
+    for (index_t col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) {
+      const index_t col_pos =
+          (((col_c * batch_sz + b) * out_h) + h) * out_w + w;

-      int out_x = col_pos % out_w;
-      int out_y = (col_pos / out_w) % out_h;
-      int j = (col_pos / (out_w * out_h * batch_sz)) % weight_w;
-      int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h;
+      index_t out_x = col_pos % out_w;
+      index_t out_y = (col_pos / out_w) % out_h;
+      index_t j = (col_pos / (out_w * out_h * batch_sz)) % weight_w;
+      index_t i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h;

-      const int mask_idx = i * weight_w + j;
-      const int offset_h_ptr =
+      const index_t mask_idx = i * weight_w + j;
+      const index_t offset_h_ptr =
           (((2 * mask_idx) * out_h + out_y) * out_w + out_x);
-      const int offset_w_ptr =
+      const index_t offset_w_ptr =
           (((2 * mask_idx + 1) * out_h + out_y) * out_w + out_x);

       const scalar_t offset_h = offset_ptr[offset_h_ptr];
       const scalar_t offset_w = offset_ptr[offset_w_ptr];
@@ -543,7 +617,7 @@ __global__ void deformable_col2im_coord_kernel(
     grad_offset[index] = grad_offset_val;

     if (use_mask && is_y_direction) {
-      const int idx =
+      const index_t idx =
           ((((b * n_offset_grps + offset_grp) * weight_h + w_h) * weight_w +
            w_w) *
               out_h +
@@ -580,15 +654,55 @@ void compute_grad_offset_and_mask(
       (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
   int out_w =
       (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1;
-  int num_kernels =
-      out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs;
+  int64_t num_kernels = (int64_t)out_h * out_w * 2 * weight_h * weight_w *
+      n_offset_grps * parallel_imgs;

   const unsigned int threads = GET_THREADS();
   const unsigned int blocks = GET_BLOCKS(threads, num_kernels);

+  // Checks if we should use 64bits indexing
+  // https://github.com/pytorch/vision/issues/4269
+  bool use_64bits_indexing = false;
+  // Checks if columns numel is larger than 2 ** 31
+  use_64bits_indexing |= num_kernels > (1 << 31);
+  use_64bits_indexing |=
+      ((int64_t)channels * weight_h * weight_w * parallel_imgs * out_h * out_w >
+       (1 << 31));
+
+  if (use_64bits_indexing) {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        columns.scalar_type(), "compute_grad_offset_and_mask", ([&] {
+          deformable_col2im_coord_kernel<scalar_t, int64_t>
+              <<<blocks, threads>>>(
+                  num_kernels,
+                  columns.data_ptr<scalar_t>(),
+                  input.data_ptr<scalar_t>(),
+                  offset.data_ptr<scalar_t>(),
+                  mask.data_ptr<scalar_t>(),
+                  channels,
+                  height,
+                  width,
+                  weight_h,
+                  weight_w,
+                  pad_h,
+                  pad_w,
+                  stride_h,
+                  stride_w,
+                  dilation_h,
+                  dilation_w,
+                  parallel_imgs,
+                  2 * weight_h * weight_w * n_offset_grps,
+                  n_offset_grps,
+                  out_h,
+                  out_w,
+                  use_mask,
+                  grad_offset.data_ptr<scalar_t>(),
+                  grad_mask.data_ptr<scalar_t>());
+        }));
+  } else {
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       columns.scalar_type(), "compute_grad_offset_and_mask", ([&] {
-        deformable_col2im_coord_kernel<<<blocks, threads>>>(
+        deformable_col2im_coord_kernel<scalar_t, int><<<blocks, threads>>>(
             num_kernels,
             columns.data_ptr<scalar_t>(),
             input.data_ptr<scalar_t>(),
@@ -614,6 +728,7 @@ void compute_grad_offset_and_mask(
             grad_offset.data_ptr<scalar_t>(),
             grad_mask.data_ptr<scalar_t>());
       }));
+  }

   cudaError_t err = cudaGetLastError();
   if (err != cudaSuccess) {