Unverified Commit 417c732c authored by shiyu1994's avatar shiyu1994 Committed by GitHub
Browse files

[CUDA] Fix row-wise histogram construction with dense data matrix (#5103)

* fix cuda exp with dense row wise

* disable usage of multi val group in cuda exp
parent 60e72d5f
...@@ -342,7 +342,7 @@ void Dataset::Construct(std::vector<std::unique_ptr<BinMapper>>* bin_mappers, ...@@ -342,7 +342,7 @@ void Dataset::Construct(std::vector<std::unique_ptr<BinMapper>>* bin_mappers,
auto is_sparse = io_config.is_enable_sparse; auto is_sparse = io_config.is_enable_sparse;
if (io_config.device_type == std::string("cuda") || io_config.device_type == std::string("cuda_exp")) { if (io_config.device_type == std::string("cuda") || io_config.device_type == std::string("cuda_exp")) {
LGBM_config_::current_device = lgbm_device_cuda; LGBM_config_::current_device = lgbm_device_cuda;
if (io_config.device_type == std::string("cuda") && is_sparse) { if ((io_config.device_type == std::string("cuda") || io_config.device_type == std::string("cuda_exp")) && is_sparse) {
Log::Warning("Using sparse features with CUDA is currently not supported."); Log::Warning("Using sparse features with CUDA is currently not supported.");
is_sparse = false; is_sparse = false;
} }
......
...@@ -284,7 +284,11 @@ void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner0( ...@@ -284,7 +284,11 @@ void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner0(
} else if (cuda_row_data_->row_ptr_bit_type() == 64) { } else if (cuda_row_data_->row_ptr_bit_type() == 64) {
LaunchConstructHistogramKernelInner1<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, uint64_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf); LaunchConstructHistogramKernelInner1<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, uint64_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
} else { } else {
Log::Fatal("Unknown row_ptr_bit_type = %d", cuda_row_data_->row_ptr_bit_type()); if (!cuda_row_data_->is_sparse()) {
LaunchConstructHistogramKernelInner1<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, uint16_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
} else {
Log::Fatal("Unknown row_ptr_bit_type = %d", cuda_row_data_->row_ptr_bit_type());
}
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment