Unverified Commit f860d0ed authored by 007gzs's avatar 007gzs Committed by GitHub
Browse files

[Fix] Fix floor/ceil error in deform_attn for windows (#1037)



* windows build doc

* fix floor ceil

* fix lint

* change floor/ceil to floorf/ceilf

* Update build.md

* recover onnx and parrots

* fix clang-lint
Co-authored-by: default avatarWRH <12756472+wangruohui@users.noreply.github.com>
parent faf6c6cd
...@@ -24,8 +24,8 @@ __device__ scalar_t ms_deform_attn_im2col_bilinear( ...@@ -24,8 +24,8 @@ __device__ scalar_t ms_deform_attn_im2col_bilinear(
const scalar_t *&bottom_data, const int &height, const int &width, const scalar_t *&bottom_data, const int &height, const int &width,
const int &nheads, const int &channels, const scalar_t &h, const int &nheads, const int &channels, const scalar_t &h,
const scalar_t &w, const int &m, const int &c) { const scalar_t &w, const int &m, const int &c) {
const int h_low = floor(h); const int h_low = floorf(h);
const int w_low = floor(w); const int w_low = floorf(w);
const int h_high = h_low + 1; const int h_high = h_low + 1;
const int w_high = w_low + 1; const int w_high = w_low + 1;
...@@ -75,8 +75,8 @@ __device__ void ms_deform_attn_col2im_bilinear( ...@@ -75,8 +75,8 @@ __device__ void ms_deform_attn_col2im_bilinear(
const scalar_t &w, const int &m, const int &c, const scalar_t &top_grad, const scalar_t &w, const int &m, const int &c, const scalar_t &top_grad,
const scalar_t &attn_weight, scalar_t *&grad_value, const scalar_t &attn_weight, scalar_t *&grad_value,
scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) { scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) {
const int h_low = floor(h); const int h_low = floorf(h);
const int w_low = floor(w); const int w_low = floorf(w);
const int h_high = h_low + 1; const int h_high = h_low + 1;
const int w_high = w_low + 1; const int w_high = w_low + 1;
...@@ -142,8 +142,8 @@ __device__ void ms_deform_attn_col2im_bilinear_gm( ...@@ -142,8 +142,8 @@ __device__ void ms_deform_attn_col2im_bilinear_gm(
const scalar_t &w, const int &m, const int &c, const scalar_t &top_grad, const scalar_t &w, const int &m, const int &c, const scalar_t &top_grad,
const scalar_t &attn_weight, scalar_t *&grad_value, const scalar_t &attn_weight, scalar_t *&grad_value,
scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) { scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) {
const int h_low = floor(h); const int h_low = floorf(h);
const int w_low = floor(w); const int w_low = floorf(w);
const int h_high = h_low + 1; const int h_high = h_low + 1;
const int w_high = w_low + 1; const int w_high = w_low + 1;
......
...@@ -146,9 +146,9 @@ void ROIAlignForward(const int nthreads, const T* input, const T* rois, ...@@ -146,9 +146,9 @@ void ROIAlignForward(const int nthreads, const T* input, const T* rois,
// We use roi_bin_grid to sample the grid and mimic integral // We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0) int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio ? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2 : ceilf(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w = int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);
// When the grid is empty, output zeros == 0/1, instead of NaN. // When the grid is empty, output zeros == 0/1, instead of NaN.
const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4 const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
...@@ -337,12 +337,13 @@ void ROIAlignBackward(const int nthreads, const T* grad_output, const T* rois, ...@@ -337,12 +337,13 @@ void ROIAlignBackward(const int nthreads, const T* grad_output, const T* rois,
} else if (pool_mode == 1) { } else if (pool_mode == 1) {
// We do average (integral) pooling inside a bin // We do average (integral) pooling inside a bin
// We use roi_bin_grid to sample the grid and mimic integral // We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0) int roi_bin_grid_h =
? sampling_ratio (sampling_ratio > 0)
: ceil(roi_height / pooled_height); // e.g., = 2 ? sampling_ratio
: ceilf(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w = (sampling_ratio > 0) int roi_bin_grid_w = (sampling_ratio > 0)
? sampling_ratio ? sampling_ratio
: ceil(roi_width / pooled_width); : ceilf(roi_width / pooled_width);
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
for (int iy = 0; iy < roi_bin_grid_h; iy++) { for (int iy = 0; iy < roi_bin_grid_h; iy++) {
......
...@@ -156,9 +156,9 @@ void ROIAlignRotatedForward(const int nthreads, const T* input, ...@@ -156,9 +156,9 @@ void ROIAlignRotatedForward(const int nthreads, const T* input,
// We use roi_bin_grid to sample the grid and mimic integral // We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0) int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio ? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2 : ceilf(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w = int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);
// We do average (integral) pooling inside a bin // We do average (integral) pooling inside a bin
const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4 const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
...@@ -322,9 +322,9 @@ void ROIAlignRotatedBackward( ...@@ -322,9 +322,9 @@ void ROIAlignRotatedBackward(
// We use roi_bin_grid to sample the grid and mimic integral // We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0) int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio ? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2 : ceilf(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w = int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);
// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y). // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
// Appropriate translation needs to be applied after. // Appropriate translation needs to be applied after.
......
...@@ -60,9 +60,9 @@ __global__ void roi_align_rotated_forward_cuda_kernel( ...@@ -60,9 +60,9 @@ __global__ void roi_align_rotated_forward_cuda_kernel(
// We use roi_bin_grid to sample the grid and mimic integral // We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sample_num > 0) int roi_bin_grid_h = (sample_num > 0)
? sample_num ? sample_num
: ceil(roi_height / pooled_height); // e.g., = 2 : ceilf(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w = int roi_bin_grid_w =
(sample_num > 0) ? sample_num : ceil(roi_width / pooled_width); (sample_num > 0) ? sample_num : ceilf(roi_width / pooled_width);
// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y). // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
// Appropriate translation needs to be applied after. // Appropriate translation needs to be applied after.
...@@ -148,9 +148,9 @@ __global__ void roi_align_rotated_backward_cuda_kernel( ...@@ -148,9 +148,9 @@ __global__ void roi_align_rotated_backward_cuda_kernel(
// We use roi_bin_grid to sample the grid and mimic integral // We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sample_num > 0) int roi_bin_grid_h = (sample_num > 0)
? sample_num ? sample_num
: ceil(roi_height / pooled_height); // e.g., = 2 : ceilf(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w = int roi_bin_grid_w =
(sample_num > 0) ? sample_num : ceil(roi_width / pooled_width); (sample_num > 0) ? sample_num : ceilf(roi_width / pooled_width);
// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y). // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
// Appropriate translation needs to be applied after. // Appropriate translation needs to be applied after.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment