Unverified Commit d893cd1f authored by shiyu1994's avatar shiyu1994 Committed by GitHub
Browse files

[CUDA] Fix integer overflow in cuda row-wise data (#5167)

parent 56ccea42
...@@ -318,15 +318,15 @@ void CUDARowData::GetDenseDataPartitioned(const BIN_TYPE* row_wise_data, std::ve ...@@ -318,15 +318,15 @@ void CUDARowData::GetDenseDataPartitioned(const BIN_TYPE* row_wise_data, std::ve
[this, num_total_columns, row_wise_data, out_data] (int /*thread_index*/, data_size_t start, data_size_t end) { [this, num_total_columns, row_wise_data, out_data] (int /*thread_index*/, data_size_t start, data_size_t end) {
for (size_t i = 0; i < feature_partition_column_index_offsets_.size() - 1; ++i) { for (size_t i = 0; i < feature_partition_column_index_offsets_.size() - 1; ++i) {
const int num_prev_columns = static_cast<int>(feature_partition_column_index_offsets_[i]); const int num_prev_columns = static_cast<int>(feature_partition_column_index_offsets_[i]);
const data_size_t offset = num_data_ * num_prev_columns; const size_t offset = static_cast<size_t>(num_data_) * static_cast<size_t>(num_prev_columns);
const int partition_column_start = feature_partition_column_index_offsets_[i]; const int partition_column_start = feature_partition_column_index_offsets_[i];
const int partition_column_end = feature_partition_column_index_offsets_[i + 1]; const int partition_column_end = feature_partition_column_index_offsets_[i + 1];
const int num_columns_in_cur_partition = partition_column_end - partition_column_start; const int num_columns_in_cur_partition = partition_column_end - partition_column_start;
for (data_size_t data_index = start; data_index < end; ++data_index) { for (data_size_t data_index = start; data_index < end; ++data_index) {
const data_size_t data_offset = offset + data_index * num_columns_in_cur_partition; const size_t data_offset = offset + data_index * num_columns_in_cur_partition;
const data_size_t read_data_offset = data_index * num_total_columns; const size_t read_data_offset = static_cast<size_t>(data_index) * num_total_columns;
for (int column_index = 0; column_index < num_columns_in_cur_partition; ++column_index) { for (int column_index = 0; column_index < num_columns_in_cur_partition; ++column_index) {
const int true_column_index = read_data_offset + column_index + partition_column_start; const size_t true_column_index = read_data_offset + column_index + partition_column_start;
const BIN_TYPE bin = row_wise_data[true_column_index]; const BIN_TYPE bin = row_wise_data[true_column_index];
out_data[data_offset + column_index] = bin; out_data[data_offset + column_index] = bin;
} }
......
...@@ -32,7 +32,7 @@ __global__ void CUDAConstructHistogramDenseKernel( ...@@ -32,7 +32,7 @@ __global__ void CUDAConstructHistogramDenseKernel(
const unsigned int num_threads_per_block = blockDim.x * blockDim.y; const unsigned int num_threads_per_block = blockDim.x * blockDim.y;
const int partition_column_start = feature_partition_column_index_offsets[blockIdx.x]; const int partition_column_start = feature_partition_column_index_offsets[blockIdx.x];
const int partition_column_end = feature_partition_column_index_offsets[blockIdx.x + 1]; const int partition_column_end = feature_partition_column_index_offsets[blockIdx.x + 1];
const BIN_TYPE* data_ptr = data + partition_column_start * num_data; const BIN_TYPE* data_ptr = data + static_cast<size_t>(partition_column_start) * num_data;
const int num_columns_in_partition = partition_column_end - partition_column_start; const int num_columns_in_partition = partition_column_end - partition_column_start;
const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x]; const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x];
const uint32_t partition_hist_end = column_hist_offsets_full[blockIdx.x + 1]; const uint32_t partition_hist_end = column_hist_offsets_full[blockIdx.x + 1];
...@@ -43,7 +43,7 @@ __global__ void CUDAConstructHistogramDenseKernel( ...@@ -43,7 +43,7 @@ __global__ void CUDAConstructHistogramDenseKernel(
} }
__syncthreads(); __syncthreads();
const unsigned int blockIdx_y = blockIdx.y; const unsigned int blockIdx_y = blockIdx.y;
const data_size_t block_start = (blockIdx_y * blockDim.y) * num_data_per_thread; const data_size_t block_start = (static_cast<size_t>(blockIdx_y) * blockDim.y) * num_data_per_thread;
const data_size_t* data_indices_ref_this_block = data_indices_ref + block_start; const data_size_t* data_indices_ref_this_block = data_indices_ref + block_start;
data_size_t block_num_data = max(0, min(num_data_in_smaller_leaf - block_start, num_data_per_thread * static_cast<data_size_t>(blockDim.y))); data_size_t block_num_data = max(0, min(num_data_in_smaller_leaf - block_start, num_data_per_thread * static_cast<data_size_t>(blockDim.y)));
const int column_index = static_cast<int>(threadIdx.x) + partition_column_start; const int column_index = static_cast<int>(threadIdx.x) + partition_column_start;
...@@ -53,7 +53,7 @@ __global__ void CUDAConstructHistogramDenseKernel( ...@@ -53,7 +53,7 @@ __global__ void CUDAConstructHistogramDenseKernel(
const data_size_t data_index = data_indices_ref_this_block[inner_data_index]; const data_size_t data_index = data_indices_ref_this_block[inner_data_index];
const score_t grad = cuda_gradients[data_index]; const score_t grad = cuda_gradients[data_index];
const score_t hess = cuda_hessians[data_index]; const score_t hess = cuda_hessians[data_index];
const uint32_t bin = static_cast<uint32_t>(data_ptr[data_index * num_columns_in_partition + threadIdx.x]); const uint32_t bin = static_cast<uint32_t>(data_ptr[static_cast<size_t>(data_index) * num_columns_in_partition + threadIdx.x]);
const uint32_t pos = bin << 1; const uint32_t pos = bin << 1;
HIST_TYPE* pos_ptr = shared_hist_ptr + pos; HIST_TYPE* pos_ptr = shared_hist_ptr + pos;
atomicAdd_block(pos_ptr, grad); atomicAdd_block(pos_ptr, grad);
...@@ -83,7 +83,7 @@ __global__ void CUDAConstructHistogramSparseKernel( ...@@ -83,7 +83,7 @@ __global__ void CUDAConstructHistogramSparseKernel(
const data_size_t* data_indices_ref = smaller_leaf_splits->data_indices_in_leaf; const data_size_t* data_indices_ref = smaller_leaf_splits->data_indices_in_leaf;
__shared__ HIST_TYPE shared_hist[SHARED_HIST_SIZE]; __shared__ HIST_TYPE shared_hist[SHARED_HIST_SIZE];
const unsigned int num_threads_per_block = blockDim.x * blockDim.y; const unsigned int num_threads_per_block = blockDim.x * blockDim.y;
const DATA_PTR_TYPE* block_row_ptr = row_ptr + blockIdx.x * (num_data + 1); const DATA_PTR_TYPE* block_row_ptr = row_ptr + static_cast<size_t>(blockIdx.x) * (num_data + 1);
const BIN_TYPE* data_ptr = data + partition_ptr[blockIdx.x]; const BIN_TYPE* data_ptr = data + partition_ptr[blockIdx.x];
const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x]; const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x];
const uint32_t partition_hist_end = column_hist_offsets_full[blockIdx.x + 1]; const uint32_t partition_hist_end = column_hist_offsets_full[blockIdx.x + 1];
...@@ -143,7 +143,7 @@ __global__ void CUDAConstructHistogramDenseKernel_GlobalMemory( ...@@ -143,7 +143,7 @@ __global__ void CUDAConstructHistogramDenseKernel_GlobalMemory(
const unsigned int num_threads_per_block = blockDim.x * blockDim.y; const unsigned int num_threads_per_block = blockDim.x * blockDim.y;
const int partition_column_start = feature_partition_column_index_offsets[blockIdx.x]; const int partition_column_start = feature_partition_column_index_offsets[blockIdx.x];
const int partition_column_end = feature_partition_column_index_offsets[blockIdx.x + 1]; const int partition_column_end = feature_partition_column_index_offsets[blockIdx.x + 1];
const BIN_TYPE* data_ptr = data + partition_column_start * num_data; const BIN_TYPE* data_ptr = data + static_cast<size_t>(partition_column_start) * num_data;
const int num_columns_in_partition = partition_column_end - partition_column_start; const int num_columns_in_partition = partition_column_end - partition_column_start;
const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x]; const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x];
const uint32_t partition_hist_end = column_hist_offsets_full[blockIdx.x + 1]; const uint32_t partition_hist_end = column_hist_offsets_full[blockIdx.x + 1];
...@@ -157,7 +157,7 @@ __global__ void CUDAConstructHistogramDenseKernel_GlobalMemory( ...@@ -157,7 +157,7 @@ __global__ void CUDAConstructHistogramDenseKernel_GlobalMemory(
__syncthreads(); __syncthreads();
const unsigned int threadIdx_y = threadIdx.y; const unsigned int threadIdx_y = threadIdx.y;
const unsigned int blockIdx_y = blockIdx.y; const unsigned int blockIdx_y = blockIdx.y;
const data_size_t block_start = (blockIdx_y * blockDim.y) * num_data_per_thread; const data_size_t block_start = (static_cast<size_t>(blockIdx_y) * blockDim.y) * num_data_per_thread;
const data_size_t* data_indices_ref_this_block = data_indices_ref + block_start; const data_size_t* data_indices_ref_this_block = data_indices_ref + block_start;
data_size_t block_num_data = max(0, min(num_data_in_smaller_leaf - block_start, num_data_per_thread * static_cast<data_size_t>(blockDim.y))); data_size_t block_num_data = max(0, min(num_data_in_smaller_leaf - block_start, num_data_per_thread * static_cast<data_size_t>(blockDim.y)));
const data_size_t num_iteration_total = (block_num_data + blockDim.y - 1) / blockDim.y; const data_size_t num_iteration_total = (block_num_data + blockDim.y - 1) / blockDim.y;
...@@ -171,7 +171,7 @@ __global__ void CUDAConstructHistogramDenseKernel_GlobalMemory( ...@@ -171,7 +171,7 @@ __global__ void CUDAConstructHistogramDenseKernel_GlobalMemory(
const data_size_t data_index = data_indices_ref_this_block[inner_data_index]; const data_size_t data_index = data_indices_ref_this_block[inner_data_index];
const score_t grad = cuda_gradients[data_index]; const score_t grad = cuda_gradients[data_index];
const score_t hess = cuda_hessians[data_index]; const score_t hess = cuda_hessians[data_index];
const uint32_t bin = static_cast<uint32_t>(data_ptr[data_index * num_columns_in_partition + threadIdx.x]); const uint32_t bin = static_cast<uint32_t>(data_ptr[static_cast<size_t>(data_index) * num_columns_in_partition + threadIdx.x]);
const uint32_t pos = bin << 1; const uint32_t pos = bin << 1;
float* pos_ptr = shared_hist_ptr + pos; float* pos_ptr = shared_hist_ptr + pos;
atomicAdd_block(pos_ptr, grad); atomicAdd_block(pos_ptr, grad);
...@@ -202,7 +202,7 @@ __global__ void CUDAConstructHistogramSparseKernel_GlobalMemory( ...@@ -202,7 +202,7 @@ __global__ void CUDAConstructHistogramSparseKernel_GlobalMemory(
const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y; const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y;
const data_size_t* data_indices_ref = smaller_leaf_splits->data_indices_in_leaf; const data_size_t* data_indices_ref = smaller_leaf_splits->data_indices_in_leaf;
const unsigned int num_threads_per_block = blockDim.x * blockDim.y; const unsigned int num_threads_per_block = blockDim.x * blockDim.y;
const DATA_PTR_TYPE* block_row_ptr = row_ptr + blockIdx.x * (num_data + 1); const DATA_PTR_TYPE* block_row_ptr = row_ptr + static_cast<size_t>(blockIdx.x) * (num_data + 1);
const BIN_TYPE* data_ptr = data + partition_ptr[blockIdx.x]; const BIN_TYPE* data_ptr = data + partition_ptr[blockIdx.x];
const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x]; const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x];
const uint32_t partition_hist_end = column_hist_offsets_full[blockIdx.x + 1]; const uint32_t partition_hist_end = column_hist_offsets_full[blockIdx.x + 1];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment