Unverified Commit 3bcc796d authored by SFMDI's avatar SFMDI Committed by GitHub
Browse files

[Fix]: Fix cuda compile error on visual studio 16.9 (#891)



* replace floor/ceil to floorf/ceilf

to eliminate cuda compilation errors in the latest version of Visual Studio 16.9

* fix lint error

* fix lint

* fix lint

* Update .pre-commit-config.yaml

* Update .pre-commit-config.yaml
Co-authored-by: default avatarwangruohui <12756472+wangruohui@users.noreply.github.com>
parent e478e9ff
...@@ -85,8 +85,8 @@ __device__ T deformable_im2col_bilinear(const T *input, const int data_width, ...@@ -85,8 +85,8 @@ __device__ T deformable_im2col_bilinear(const T *input, const int data_width,
return 0; return 0;
} }
int h_low = floor(h); int h_low = floorf(h);
int w_low = floor(w); int w_low = floorf(w);
int h_high = h_low + 1; int h_high = h_low + 1;
int w_high = w_low + 1; int w_high = w_low + 1;
...@@ -122,8 +122,8 @@ __device__ T get_gradient_weight(T argmax_h, T argmax_w, const int h, ...@@ -122,8 +122,8 @@ __device__ T get_gradient_weight(T argmax_h, T argmax_w, const int h,
return 0; return 0;
} }
int argmax_h_low = floor(argmax_h); int argmax_h_low = floorf(argmax_h);
int argmax_w_low = floor(argmax_w); int argmax_w_low = floorf(argmax_w);
int argmax_h_high = argmax_h_low + 1; int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1; int argmax_w_high = argmax_w_low + 1;
...@@ -149,8 +149,8 @@ __device__ T get_coordinate_weight(T argmax_h, T argmax_w, const int height, ...@@ -149,8 +149,8 @@ __device__ T get_coordinate_weight(T argmax_h, T argmax_w, const int height,
return 0; return 0;
} }
int argmax_h_low = floor(argmax_h); int argmax_h_low = floorf(argmax_h);
int argmax_w_low = floor(argmax_w); int argmax_w_low = floorf(argmax_w);
int argmax_h_high = argmax_h_low + 1; int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1; int argmax_w_high = argmax_w_low + 1;
......
...@@ -42,10 +42,11 @@ __global__ void deform_roi_pool_forward_cuda_kernel( ...@@ -42,10 +42,11 @@ __global__ void deform_roi_pool_forward_cuda_kernel(
int roi_bin_grid_h = int roi_bin_grid_h =
(sampling_ratio > 0) (sampling_ratio > 0)
? sampling_ratio ? sampling_ratio
: static_cast<int>(ceil(roi_height / pooled_height)); : static_cast<int>(ceilf(roi_height / pooled_height));
int roi_bin_grid_w = (sampling_ratio > 0) int roi_bin_grid_w =
? sampling_ratio (sampling_ratio > 0)
: static_cast<int>(ceil(roi_width / pooled_width)); ? sampling_ratio
: static_cast<int>(ceilf(roi_width / pooled_width));
// Compute roi offset // Compute roi offset
if (offset != NULL) { if (offset != NULL) {
...@@ -113,10 +114,11 @@ __global__ void deform_roi_pool_backward_cuda_kernel( ...@@ -113,10 +114,11 @@ __global__ void deform_roi_pool_backward_cuda_kernel(
int roi_bin_grid_h = int roi_bin_grid_h =
(sampling_ratio > 0) (sampling_ratio > 0)
? sampling_ratio ? sampling_ratio
: static_cast<int>(ceil(roi_height / pooled_height)); : static_cast<int>(ceilf(roi_height / pooled_height));
int roi_bin_grid_w = (sampling_ratio > 0) int roi_bin_grid_w =
? sampling_ratio (sampling_ratio > 0)
: static_cast<int>(ceil(roi_width / pooled_width)); ? sampling_ratio
: static_cast<int>(ceilf(roi_width / pooled_width));
// Compute roi offset // Compute roi offset
if (offset != NULL) { if (offset != NULL) {
......
...@@ -75,8 +75,8 @@ ...@@ -75,8 +75,8 @@
template <typename T> template <typename T>
__device__ T dmcn_im2col_bilinear(const T *input, const int data_width, __device__ T dmcn_im2col_bilinear(const T *input, const int data_width,
const int height, const int width, T h, T w) { const int height, const int width, T h, T w) {
int h_low = floor(h); int h_low = floorf(h);
int w_low = floor(w); int w_low = floorf(w);
int h_high = h_low + 1; int h_high = h_low + 1;
int w_high = w_low + 1; int w_high = w_low + 1;
...@@ -112,8 +112,8 @@ __device__ T dmcn_get_gradient_weight(T argmax_h, T argmax_w, const int h, ...@@ -112,8 +112,8 @@ __device__ T dmcn_get_gradient_weight(T argmax_h, T argmax_w, const int h,
return 0; return 0;
} }
int argmax_h_low = floor(argmax_h); int argmax_h_low = floorf(argmax_h);
int argmax_w_low = floor(argmax_w); int argmax_w_low = floorf(argmax_w);
int argmax_h_high = argmax_h_low + 1; int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1; int argmax_w_high = argmax_w_low + 1;
...@@ -140,8 +140,8 @@ __device__ T dmcn_get_coordinate_weight(T argmax_h, T argmax_w, ...@@ -140,8 +140,8 @@ __device__ T dmcn_get_coordinate_weight(T argmax_h, T argmax_w,
return 0; return 0;
} }
int argmax_h_low = floor(argmax_h); int argmax_h_low = floorf(argmax_h);
int argmax_w_low = floor(argmax_w); int argmax_w_low = floorf(argmax_w);
int argmax_h_high = argmax_h_low + 1; int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1; int argmax_w_high = argmax_w_low + 1;
......
...@@ -54,10 +54,11 @@ __global__ void roi_align_forward_cuda_kernel( ...@@ -54,10 +54,11 @@ __global__ void roi_align_forward_cuda_kernel(
int roi_bin_grid_h = int roi_bin_grid_h =
(sampling_ratio > 0) (sampling_ratio > 0)
? sampling_ratio ? sampling_ratio
: static_cast<int>(ceil(roi_height / pooled_height)); : static_cast<int>(ceilf(roi_height / pooled_height));
int roi_bin_grid_w = (sampling_ratio > 0) int roi_bin_grid_w =
? sampling_ratio (sampling_ratio > 0)
: static_cast<int>(ceil(roi_width / pooled_width)); ? sampling_ratio
: static_cast<int>(ceilf(roi_width / pooled_width));
if (pool_mode == 0) { if (pool_mode == 0) {
// We do max pooling inside a bin // We do max pooling inside a bin
...@@ -168,11 +169,11 @@ __global__ void roi_align_backward_cuda_kernel( ...@@ -168,11 +169,11 @@ __global__ void roi_align_backward_cuda_kernel(
int roi_bin_grid_h = int roi_bin_grid_h =
(sampling_ratio > 0) (sampling_ratio > 0)
? sampling_ratio ? sampling_ratio
: static_cast<int>(ceil(roi_height / pooled_height)); : static_cast<int>(ceilf(roi_height / pooled_height));
int roi_bin_grid_w = int roi_bin_grid_w =
(sampling_ratio > 0) (sampling_ratio > 0)
? sampling_ratio ? sampling_ratio
: static_cast<int>(ceil(roi_width / pooled_width)); : static_cast<int>(ceilf(roi_width / pooled_width));
// We do average (integral) pooling inside a bin // We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
......
...@@ -36,10 +36,10 @@ __global__ void roi_pool_forward_cuda_kernel( ...@@ -36,10 +36,10 @@ __global__ void roi_pool_forward_cuda_kernel(
T bin_size_h = roi_h / static_cast<T>(pooled_height); T bin_size_h = roi_h / static_cast<T>(pooled_height);
// the corresponding bin region // the corresponding bin region
int bin_x1 = floor(static_cast<T>(pw) * bin_size_w + roi_x1); int bin_x1 = floorf(static_cast<T>(pw) * bin_size_w + roi_x1);
int bin_y1 = floor(static_cast<T>(ph) * bin_size_h + roi_y1); int bin_y1 = floorf(static_cast<T>(ph) * bin_size_h + roi_y1);
int bin_x2 = ceil(static_cast<T>(pw + 1) * bin_size_w + roi_x1); int bin_x2 = ceilf(static_cast<T>(pw + 1) * bin_size_w + roi_x1);
int bin_y2 = ceil(static_cast<T>(ph + 1) * bin_size_h + roi_y1); int bin_y2 = ceilf(static_cast<T>(ph + 1) * bin_size_h + roi_y1);
// add roi offsets and clip to input boundaries // add roi offsets and clip to input boundaries
bin_x1 = min(max(bin_x1, 0), width); bin_x1 = min(max(bin_x1, 0), width);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment