Unverified commit f901f471, authored by shiyu1994 and committed by GitHub

[CUDA] CUDA Quantized Training (fixes #5606) (#5933)

* add quantized training (first stage)

* add histogram construction functions for integer gradients

* add stochastic rounding

* update docs

* fix compilation errors by adding template instantiations

* update files for compilation

* fix compilation of gpu version

* initialize gradient discretizer before share states

* add a test case for quantized training

* add quantized training for data distributed training

* Delete origin.pred

* Delete ifelse.pred

* Delete LightGBM_model.txt

* remove useless changes

* fix lint error

* remove debug loggings

* fix mismatch of vector and allocator types

* remove changes in main.cpp

* fix bugs with uninitialized gradient discretizer

* initialize ordered gradients in gradient discretizer

* disable quantized training with gpu and cuda

fix msvc compilation errors and warnings

* fix bug in data parallel tree learner

* make quantized training test deterministic

* make quantized training in test case more accurate

* refactor test_quantized_training

* fix leaf splits initialization with quantized training

* check distributed quantized training result

* add cuda gradient discretizer

* add quantized training for CUDA version in tree learner

* remove cuda compute capability 6.1 and 6.2

* fix parts of gpu quantized training errors and warnings

* fix build-python.sh to install locally built version

* fix memory access bugs

* fix lint errors

* mark cuda quantized training with categorical features as unsupported

* rename cuda_utils.h to cuda_utils.hu

* enable quantized training with cuda

* fix cuda quantized training with sparse row data

* allow using global memory buffer in histogram construction with cuda quantized training

* recover build-python.sh

enlarge allowed package size to 100M
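
In outline, the training-time quantization that the new CUDA code below implements maps each gradient/hessian pair to a pair of int16 bins using per-iteration scales taken from the absolute maxima (see ReduceBlockMinMaxKernel and DiscretizeGradientsKernel further down). A minimal host-side sketch of that scheme, with nearest rounding and hypothetical names, illustrative only and not LightGBM code:

// --- illustrative sketch, not part of this commit ---
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

void DiscretizeOnHost(const std::vector<float>& gradients, const std::vector<float>& hessians,
                      int num_grad_quant_bins, std::vector<int16_t>* packed) {
  float grad_abs_max = 1e-10f, hess_abs_max = 1e-10f;  // small floor avoids division by zero
  for (size_t i = 0; i < gradients.size(); ++i) {
    grad_abs_max = std::max(grad_abs_max, std::fabs(gradients[i]));
    hess_abs_max = std::max(hess_abs_max, std::fabs(hessians[i]));
  }
  // gradients can be negative, so they get half the bins on each side;
  // hessians are non-negative and use the full range
  const float grad_scale = grad_abs_max / (num_grad_quant_bins / 2);
  const float hess_scale = hess_abs_max / num_grad_quant_bins;
  packed->resize(gradients.size() * 2);
  for (size_t i = 0; i < gradients.size(); ++i) {
    const float g = gradients[i] / grad_scale;
    const float h = hessians[i] / hess_scale;
    (*packed)[2 * i + 1] = static_cast<int16_t>(g > 0.0f ? g + 0.5f : g - 0.5f);  // gradient in the odd slot
    (*packed)[2 * i] = static_cast<int16_t>(h + 0.5f);                            // hessian in the even slot
  }
}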
@@ -1069,6 +1069,53 @@ void CUDADataPartition::LaunchAddPredictionToScoreKernel(const double* leaf_valu
global_timer.Stop("CUDADataPartition::AddPredictionToScoreKernel");
}
__global__ void RenewDiscretizedTreeLeavesKernel(
const score_t* gradients,
const score_t* hessians,
const data_size_t* data_indices,
const data_size_t* leaf_data_start,
const data_size_t* leaf_num_data,
double* leaf_grad_stat_buffer,
double* leaf_hess_stat_buffer,
double* leaf_values) {
__shared__ double shared_mem_buffer[32];
const int leaf_index = static_cast<int>(blockIdx.x);
const data_size_t* data_indices_in_leaf = data_indices + leaf_data_start[leaf_index];
const data_size_t num_data_in_leaf = leaf_num_data[leaf_index];
double sum_gradients = 0.0f;
double sum_hessians = 0.0f;
for (data_size_t inner_data_index = static_cast<int>(threadIdx.x);
inner_data_index < num_data_in_leaf; inner_data_index += static_cast<int>(blockDim.x)) {
const data_size_t data_index = data_indices_in_leaf[inner_data_index];
const score_t gradient = gradients[data_index];
const score_t hessian = hessians[data_index];
sum_gradients += static_cast<double>(gradient);
sum_hessians += static_cast<double>(hessian);
}
sum_gradients = ShuffleReduceSum<double>(sum_gradients, shared_mem_buffer, blockDim.x);
__syncthreads();
sum_hessians = ShuffleReduceSum<double>(sum_hessians, shared_mem_buffer, blockDim.x);
if (threadIdx.x == 0) {
leaf_grad_stat_buffer[leaf_index] = sum_gradients;
leaf_hess_stat_buffer[leaf_index] = sum_hessians;
}
}
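
The kernel above (and most kernels in this PR) relies on ShuffleReduceSum / ShuffleReduceMin / ShuffleReduceMax from cuda_algorithms.hpp together with a 32-element shared buffer. A minimal sketch of the kind of block reduction such a helper is assumed to perform (illustrative only, not the actual LightGBM implementation):

// --- illustrative sketch, not part of this commit ---
// Each warp reduces its value with __shfl_down_sync, warp leaders write to the
// 32-element shared buffer, and the first warp reduces that buffer; thread 0 holds the result.
template <typename T>
__device__ T BlockReduceSumSketch(T val, T* shared_mem_buffer) {
  const unsigned int full_mask = 0xffffffffu;
  const int lane = static_cast<int>(threadIdx.x) & 31;
  const int warp = static_cast<int>(threadIdx.x) >> 5;
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(full_mask, val, offset);
  }
  if (lane == 0) {
    shared_mem_buffer[warp] = val;
  }
  __syncthreads();
  const int num_warps = (static_cast<int>(blockDim.x) + 31) >> 5;
  val = (static_cast<int>(threadIdx.x) < num_warps) ? shared_mem_buffer[threadIdx.x] : static_cast<T>(0);
  if (warp == 0) {
    for (int offset = 16; offset > 0; offset >>= 1) {
      val += __shfl_down_sync(full_mask, val, offset);
    }
  }
  return val;  // valid on threadIdx.x == 0
}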
void CUDADataPartition::LaunchReduceLeafGradStat(
const score_t* gradients, const score_t* hessians,
CUDATree* tree, double* leaf_grad_stat_buffer, double* leaf_hess_state_buffer) const {
const int num_blocks = tree->num_leaves();
RenewDiscretizedTreeLeavesKernel<<<num_blocks, FILL_INDICES_BLOCK_SIZE_DATA_PARTITION>>>(
gradients,
hessians,
cuda_data_indices_,
cuda_leaf_data_start_,
cuda_leaf_num_data_,
leaf_grad_stat_buffer,
leaf_hess_state_buffer,
tree->cuda_leaf_value_ref());
}
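
Note that RenewDiscretizedTreeLeavesKernel only accumulates the per-leaf sums of the original (un-quantized) gradients and hessians into leaf_grad_stat_buffer and leaf_hess_stat_buffer; the leaf_values pointer is passed through so that a follow-up step (not shown in this hunk) can presumably renew the leaf outputs of the quantized tree from those exact sums.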
}  // namespace LightGBM
#endif  // USE_CUDA
@@ -78,6 +78,10 @@ class CUDADataPartition {
void ResetByLeafPred(const std::vector<int>& leaf_pred, int num_leaves);
void ReduceLeafGradStat(
const score_t* gradients, const score_t* hessians,
CUDATree* tree, double* leaf_grad_stat_buffer, double* leaf_hess_state_buffer) const;
data_size_t root_num_data() const {
if (use_bagging_) {
return num_used_indices_;
@@ -292,6 +296,10 @@ class CUDADataPartition {
void LaunchFillDataIndexToLeafIndex();
void LaunchReduceLeafGradStat(
const score_t* gradients, const score_t* hessians,
CUDATree* tree, double* leaf_grad_stat_buffer, double* leaf_hess_state_buffer) const;
// Host memory
// dataset information
...
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/
#ifdef USE_CUDA
#include <algorithm>
#include <LightGBM/cuda/cuda_algorithms.hpp>
#include "cuda_gradient_discretizer.hpp"
namespace LightGBM {
__global__ void ReduceMinMaxKernel(
const data_size_t num_data,
const score_t* input_gradients,
const score_t* input_hessians,
score_t* grad_min_block_buffer,
score_t* grad_max_block_buffer,
score_t* hess_min_block_buffer,
score_t* hess_max_block_buffer) {
__shared__ score_t shared_mem_buffer[32];
const data_size_t index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
score_t grad_max_val = kMinScore;
score_t grad_min_val = kMaxScore;
score_t hess_max_val = kMinScore;
score_t hess_min_val = kMaxScore;
if (index < num_data) {
grad_max_val = input_gradients[index];
grad_min_val = input_gradients[index];
hess_max_val = input_hessians[index];
hess_min_val = input_hessians[index];
}
grad_min_val = ShuffleReduceMin<score_t>(grad_min_val, shared_mem_buffer, blockDim.x);
__syncthreads();
grad_max_val = ShuffleReduceMax<score_t>(grad_max_val, shared_mem_buffer, blockDim.x);
__syncthreads();
hess_min_val = ShuffleReduceMin<score_t>(hess_min_val, shared_mem_buffer, blockDim.x);
__syncthreads();
hess_max_val = ShuffleReduceMax<score_t>(hess_max_val, shared_mem_buffer, blockDim.x);
if (threadIdx.x == 0) {
grad_min_block_buffer[blockIdx.x] = grad_min_val;
grad_max_block_buffer[blockIdx.x] = grad_max_val;
hess_min_block_buffer[blockIdx.x] = hess_min_val;
hess_max_block_buffer[blockIdx.x] = hess_max_val;
}
}
__global__ void ReduceBlockMinMaxKernel(
const int num_blocks,
const int grad_discretize_bins,
score_t* grad_min_block_buffer,
score_t* grad_max_block_buffer,
score_t* hess_min_block_buffer,
score_t* hess_max_block_buffer) {
__shared__ score_t shared_mem_buffer[32];
score_t grad_max_val = kMinScore;
score_t grad_min_val = kMaxScore;
score_t hess_max_val = kMinScore;
score_t hess_min_val = kMaxScore;
for (int block_index = static_cast<int>(threadIdx.x); block_index < num_blocks; block_index += static_cast<int>(blockDim.x)) {
grad_min_val = min(grad_min_val, grad_min_block_buffer[block_index]);
grad_max_val = max(grad_max_val, grad_max_block_buffer[block_index]);
hess_min_val = min(hess_min_val, hess_min_block_buffer[block_index]);
hess_max_val = max(hess_max_val, hess_max_block_buffer[block_index]);
}
grad_min_val = ShuffleReduceMin<score_t>(grad_min_val, shared_mem_buffer, blockDim.x);
__syncthreads();
grad_max_val = ShuffleReduceMax<score_t>(grad_max_val, shared_mem_buffer, blockDim.x);
__syncthreads();
hess_min_val = ShuffleReduceMin<score_t>(hess_min_val, shared_mem_buffer, blockDim.x);
__syncthreads();
hess_max_val = ShuffleReduceMax<score_t>(hess_max_val, shared_mem_buffer, blockDim.x);
if (threadIdx.x == 0) {
const score_t grad_abs_max = max(fabs(grad_min_val), fabs(grad_max_val));
const score_t hess_abs_max = max(fabs(hess_min_val), fabs(hess_max_val));
grad_min_block_buffer[0] = 1.0f / (grad_abs_max / (grad_discretize_bins / 2));
grad_max_block_buffer[0] = (grad_abs_max / (grad_discretize_bins / 2));
hess_min_block_buffer[0] = 1.0f / (hess_abs_max / (grad_discretize_bins));
hess_max_block_buffer[0] = (hess_abs_max / (grad_discretize_bins));
}
}
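
As a concrete example, with num_grad_quant_bins = 4 and a gradient absolute maximum of 2.0, the kernel stores the integer-to-real scale 2.0 / (4 / 2) = 1.0 in grad_max_block_buffer[0] and its reciprocal (also 1.0 here) in grad_min_block_buffer[0]; it is the reciprocal that DiscretizeGradientsKernel below multiplies each gradient by before rounding to an int16 bin.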
template <bool STOCHASTIC_ROUNDING>
__global__ void DiscretizeGradientsKernel(
const data_size_t num_data,
const score_t* input_gradients,
const score_t* input_hessians,
const score_t* grad_scale_ptr,
const score_t* hess_scale_ptr,
const int iter,
const int* random_values_use_start,
const score_t* gradient_random_values,
const score_t* hessian_random_values,
const int grad_discretize_bins,
int8_t* output_gradients_and_hessians) {
const int start = random_values_use_start[iter];
const data_size_t index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
const score_t grad_scale = *grad_scale_ptr;
const score_t hess_scale = *hess_scale_ptr;
int16_t* output_gradients_and_hessians_ptr = reinterpret_cast<int16_t*>(output_gradients_and_hessians);
if (index < num_data) {
if (STOCHASTIC_ROUNDING) {
const data_size_t index_offset = (index + start) % num_data;
const score_t gradient = input_gradients[index];
const score_t hessian = input_hessians[index];
const score_t gradient_random_value = gradient_random_values[index_offset];
const score_t hessian_random_value = hessian_random_values[index_offset];
output_gradients_and_hessians_ptr[2 * index + 1] = gradient > 0.0f ?
static_cast<int16_t>(gradient * grad_scale + gradient_random_value) :
static_cast<int16_t>(gradient * grad_scale - gradient_random_value);
output_gradients_and_hessians_ptr[2 * index] = static_cast<int16_t>(hessian * hess_scale + hessian_random_value);
} else {
const score_t gradient = input_gradients[index];
const score_t hessian = input_hessians[index];
output_gradients_and_hessians_ptr[2 * index + 1] = gradient > 0.0f ?
static_cast<int16_t>(gradient * grad_scale + 0.5) :
static_cast<int16_t>(gradient * grad_scale - 0.5);
output_gradients_and_hessians_ptr[2 * index] = static_cast<int16_t>(hessian * hess_scale + 0.5);
}
}
}
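
When STOCHASTIC_ROUNDING is enabled, the kernel adds a uniform [0, 1) random value before the truncating cast, which rounds a value away from zero with probability equal to its fractional part, so the quantized gradient is unbiased in expectation. A minimal host-side sketch of that rounding rule (illustrative only, not part of this commit):

// --- illustrative sketch, not part of this commit ---
#include <cstdint>
#include <random>

// Stochastic rounding as used above: truncation toward zero after adding (for
// positive values) or subtracting (for negative values) a uniform [0, 1) draw.
int16_t StochasticRound(double scaled_value, std::mt19937* rng) {
  std::uniform_real_distribution<double> uniform(0.0, 1.0);
  const double r = uniform(*rng);
  return static_cast<int16_t>(scaled_value > 0.0 ? scaled_value + r : scaled_value - r);
}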
void CUDAGradientDiscretizer::DiscretizeGradients(
const data_size_t num_data,
const score_t* input_gradients,
const score_t* input_hessians) {
ReduceMinMaxKernel<<<num_reduce_blocks_, CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE>>>(
num_data, input_gradients, input_hessians,
grad_min_block_buffer_.RawData(),
grad_max_block_buffer_.RawData(),
hess_min_block_buffer_.RawData(),
hess_max_block_buffer_.RawData());
SynchronizeCUDADevice(__FILE__, __LINE__);
ReduceBlockMinMaxKernel<<<1, CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE>>>(
num_reduce_blocks_,
num_grad_quant_bins_,
grad_min_block_buffer_.RawData(),
grad_max_block_buffer_.RawData(),
hess_min_block_buffer_.RawData(),
hess_max_block_buffer_.RawData());
SynchronizeCUDADevice(__FILE__, __LINE__);
#define DiscretizeGradientsKernel_ARGS \
num_data, \
input_gradients, \
input_hessians, \
grad_min_block_buffer_.RawData(), \
hess_min_block_buffer_.RawData(), \
iter_, \
random_values_use_start_.RawData(), \
gradient_random_values_.RawData(), \
hessian_random_values_.RawData(), \
num_grad_quant_bins_, \
discretized_gradients_and_hessians_.RawData()
if (stochastic_rounding_) {
DiscretizeGradientsKernel<true><<<num_reduce_blocks_, CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE>>>(DiscretizeGradientsKernel_ARGS);
} else {
DiscretizeGradientsKernel<false><<<num_reduce_blocks_, CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE>>>(DiscretizeGradientsKernel_ARGS);
}
SynchronizeCUDADevice(__FILE__, __LINE__);
++iter_;
}
} // namespace LightGBM
#endif // USE_CUDA
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/
#ifndef LIGHTGBM_TREELEARNER_CUDA_CUDA_GRADIENT_DISCRETIZER_HPP_
#define LIGHTGBM_TREELEARNER_CUDA_CUDA_GRADIENT_DISCRETIZER_HPP_
#ifdef USE_CUDA
#include <LightGBM/bin.h>
#include <LightGBM/meta.h>
#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/utils/threading.h>
#include <algorithm>
#include <random>
#include <vector>
#include "cuda_leaf_splits.hpp"
#include "../gradient_discretizer.hpp"
namespace LightGBM {
#define CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE (1024)
class CUDAGradientDiscretizer: public GradientDiscretizer {
public:
CUDAGradientDiscretizer(int num_grad_quant_bins, int num_trees, int random_seed, bool is_constant_hessian, bool stochastic_rounding):
GradientDiscretizer(num_grad_quant_bins, num_trees, random_seed, is_constant_hessian, stochastic_rounding) {
}
void DiscretizeGradients(
const data_size_t num_data,
const score_t* input_gradients,
const score_t* input_hessians) override;
const int8_t* discretized_gradients_and_hessians() const override { return discretized_gradients_and_hessians_.RawData(); }
double grad_scale() const override {
Log::Fatal("grad_scale() of CUDAGradientDiscretizer should not be called.");
return 0.0;
}
double hess_scale() const override {
Log::Fatal("hess_scale() of CUDAGradientDiscretizer should not be called.");
return 0.0;
}
const score_t* grad_scale_ptr() const { return grad_max_block_buffer_.RawData(); }
const score_t* hess_scale_ptr() const { return hess_max_block_buffer_.RawData(); }
void Init(const data_size_t num_data, const int num_leaves,
const int num_features, const Dataset* train_data) override {
GradientDiscretizer::Init(num_data, num_leaves, num_features, train_data);
discretized_gradients_and_hessians_.Resize(num_data * 2);
num_reduce_blocks_ = (num_data + CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE - 1) / CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE;
grad_min_block_buffer_.Resize(num_reduce_blocks_);
grad_max_block_buffer_.Resize(num_reduce_blocks_);
hess_min_block_buffer_.Resize(num_reduce_blocks_);
hess_max_block_buffer_.Resize(num_reduce_blocks_);
random_values_use_start_.Resize(num_trees_);
gradient_random_values_.Resize(num_data);
hessian_random_values_.Resize(num_data);
std::vector<score_t> gradient_random_values(num_data, 0.0f);
std::vector<score_t> hessian_random_values(num_data, 0.0f);
std::vector<int> random_values_use_start(num_trees_, 0);
const int num_threads = OMP_NUM_THREADS();
std::mt19937 random_values_use_start_eng = std::mt19937(random_seed_);
std::uniform_int_distribution<data_size_t> random_values_use_start_dist = std::uniform_int_distribution<data_size_t>(0, num_data);
for (int tree_index = 0; tree_index < num_trees_; ++tree_index) {
random_values_use_start[tree_index] = random_values_use_start_dist(random_values_use_start_eng);
}
int num_blocks = 0;
data_size_t block_size = 0;
Threading::BlockInfo<data_size_t>(num_data, 512, &num_blocks, &block_size);
#pragma omp parallel for schedule(static, 1) num_threads(num_threads)
for (int thread_id = 0; thread_id < num_blocks; ++thread_id) {
const data_size_t start = thread_id * block_size;
const data_size_t end = std::min(start + block_size, num_data);
std::mt19937 gradient_random_values_eng(random_seed_ + thread_id);
std::uniform_real_distribution<double> gradient_random_values_dist(0.0f, 1.0f);
std::mt19937 hessian_random_values_eng(random_seed_ + thread_id + num_threads);
std::uniform_real_distribution<double> hessian_random_values_dist(0.0f, 1.0f);
for (data_size_t i = start; i < end; ++i) {
gradient_random_values[i] = gradient_random_values_dist(gradient_random_values_eng);
hessian_random_values[i] = hessian_random_values_dist(hessian_random_values_eng);
}
}
CopyFromHostToCUDADevice<score_t>(gradient_random_values_.RawData(), gradient_random_values.data(), gradient_random_values.size(), __FILE__, __LINE__);
CopyFromHostToCUDADevice<score_t>(hessian_random_values_.RawData(), hessian_random_values.data(), hessian_random_values.size(), __FILE__, __LINE__);
CopyFromHostToCUDADevice<int>(random_values_use_start_.RawData(), random_values_use_start.data(), random_values_use_start.size(), __FILE__, __LINE__);
iter_ = 0;
}
protected:
mutable CUDAVector<int8_t> discretized_gradients_and_hessians_;
mutable CUDAVector<score_t> grad_min_block_buffer_;
mutable CUDAVector<score_t> grad_max_block_buffer_;
mutable CUDAVector<score_t> hess_min_block_buffer_;
mutable CUDAVector<score_t> hess_max_block_buffer_;
CUDAVector<int> random_values_use_start_;
CUDAVector<score_t> gradient_random_values_;
CUDAVector<score_t> hessian_random_values_;
int num_reduce_blocks_;
};
} // namespace LightGBM
#endif // USE_CUDA
#endif // LIGHTGBM_TREELEARNER_CUDA_CUDA_GRADIENT_DISCRETIZER_HPP_
@@ -20,7 +20,9 @@ CUDAHistogramConstructor::CUDAHistogramConstructor(
const int min_data_in_leaf,
const double min_sum_hessian_in_leaf,
const int gpu_device_id,
const bool gpu_use_dp,
const bool use_quantized_grad,
const int num_grad_quant_bins):
num_data_(train_data->num_data()),
num_features_(train_data->num_features()),
num_leaves_(num_leaves),
@@ -28,24 +30,14 @@ CUDAHistogramConstructor::CUDAHistogramConstructor(
min_data_in_leaf_(min_data_in_leaf),
min_sum_hessian_in_leaf_(min_sum_hessian_in_leaf),
gpu_device_id_(gpu_device_id),
gpu_use_dp_(gpu_use_dp),
use_quantized_grad_(use_quantized_grad),
num_grad_quant_bins_(num_grad_quant_bins) {
InitFeatureMetaInfo(train_data, feature_hist_offsets);
cuda_row_data_.reset(nullptr);
- cuda_feature_num_bins_ = nullptr;
- cuda_feature_hist_offsets_ = nullptr;
- cuda_feature_most_freq_bins_ = nullptr;
- cuda_hist_ = nullptr;
- cuda_need_fix_histogram_features_ = nullptr;
- cuda_need_fix_histogram_features_num_bin_aligned_ = nullptr;
}
CUDAHistogramConstructor::~CUDAHistogramConstructor() {
- DeallocateCUDAMemory<uint32_t>(&cuda_feature_num_bins_, __FILE__, __LINE__);
- DeallocateCUDAMemory<uint32_t>(&cuda_feature_hist_offsets_, __FILE__, __LINE__);
- DeallocateCUDAMemory<uint32_t>(&cuda_feature_most_freq_bins_, __FILE__, __LINE__);
- DeallocateCUDAMemory<hist_t>(&cuda_hist_, __FILE__, __LINE__);
- DeallocateCUDAMemory<int>(&cuda_need_fix_histogram_features_, __FILE__, __LINE__);
- DeallocateCUDAMemory<uint32_t>(&cuda_need_fix_histogram_features_num_bin_aligned_, __FILE__, __LINE__);
gpuAssert(cudaStreamDestroy(cuda_stream_), __FILE__, __LINE__);
}
@@ -84,54 +76,70 @@ void CUDAHistogramConstructor::InitFeatureMetaInfo(const Dataset* train_data, co
void CUDAHistogramConstructor::BeforeTrain(const score_t* gradients, const score_t* hessians) {
cuda_gradients_ = gradients;
cuda_hessians_ = hessians;
- SetCUDAMemory<hist_t>(cuda_hist_, 0, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
cuda_hist_.SetValue(0);
}
void CUDAHistogramConstructor::Init(const Dataset* train_data, TrainingShareStates* share_state) {
- AllocateCUDAMemory<hist_t>(&cuda_hist_, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
- SetCUDAMemory<hist_t>(cuda_hist_, 0, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
- InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_num_bins_,
-   feature_num_bins_.data(), feature_num_bins_.size(), __FILE__, __LINE__);
- InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_hist_offsets_,
-   feature_hist_offsets_.data(), feature_hist_offsets_.size(), __FILE__, __LINE__);
- InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_most_freq_bins_,
-   feature_most_freq_bins_.data(), feature_most_freq_bins_.size(), __FILE__, __LINE__);
cuda_hist_.Resize(static_cast<size_t>(num_total_bin_ * 2 * num_leaves_));
cuda_hist_.SetValue(0);
cuda_feature_num_bins_.InitFromHostVector(feature_num_bins_);
cuda_feature_hist_offsets_.InitFromHostVector(feature_hist_offsets_);
cuda_feature_most_freq_bins_.InitFromHostVector(feature_most_freq_bins_);
cuda_row_data_.reset(new CUDARowData(train_data, share_state, gpu_device_id_, gpu_use_dp_));
cuda_row_data_->Init(train_data, share_state);
CUDASUCCESS_OR_FATAL(cudaStreamCreate(&cuda_stream_));
- InitCUDAMemoryFromHostMemory<int>(&cuda_need_fix_histogram_features_, need_fix_histogram_features_.data(), need_fix_histogram_features_.size(), __FILE__, __LINE__);
- InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_need_fix_histogram_features_num_bin_aligned_, need_fix_histogram_features_num_bin_aligend_.data(),
-   need_fix_histogram_features_num_bin_aligend_.size(), __FILE__, __LINE__);
cuda_need_fix_histogram_features_.InitFromHostVector(need_fix_histogram_features_);
cuda_need_fix_histogram_features_num_bin_aligned_.InitFromHostVector(need_fix_histogram_features_num_bin_aligend_);
if (cuda_row_data_->NumLargeBinPartition() > 0) {
int grid_dim_x = 0, grid_dim_y = 0, block_dim_x = 0, block_dim_y = 0;
CalcConstructHistogramKernelDim(&grid_dim_x, &grid_dim_y, &block_dim_x, &block_dim_y, num_data_);
- const size_t buffer_size = static_cast<size_t>(grid_dim_y) * static_cast<size_t>(num_total_bin_) * 2;
- AllocateCUDAMemory<float>(&cuda_hist_buffer_, buffer_size, __FILE__, __LINE__);
const size_t buffer_size = static_cast<size_t>(grid_dim_y) * static_cast<size_t>(num_total_bin_);
if (!use_quantized_grad_) {
if (gpu_use_dp_) {
// need to double the size of histogram buffer in global memory when using double precision in histogram construction
cuda_hist_buffer_.Resize(buffer_size * 4);
} else {
cuda_hist_buffer_.Resize(buffer_size * 2);
}
} else {
// use only half the size of histogram buffer in global memory when quantized training since each gradient and hessian takes only 2 bytes
cuda_hist_buffer_.Resize(buffer_size);
}
}
hist_buffer_for_num_bit_change_.Resize(num_total_bin_ * 2);
}
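
As a rough sizing example (following the comments above): with grid_dim_y = 16 and num_total_bin_ = 1000, buffer_size is 16,000, so the per-block global histogram buffer holds 64,000 floats with double-precision histograms (two doubles per bin), 32,000 floats in single precision (two floats per bin), and only 16,000 floats under quantized training, where each bin stores one packed 32-bit gradient/hessian pair.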
void CUDAHistogramConstructor::ConstructHistogramForLeaf(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const CUDALeafSplitsStruct* /*cuda_larger_leaf_splits*/,
const data_size_t num_data_in_smaller_leaf,
const data_size_t num_data_in_larger_leaf,
const double sum_hessians_in_smaller_leaf,
const double sum_hessians_in_larger_leaf,
const uint8_t num_bits_in_histogram_bins) {
if ((num_data_in_smaller_leaf <= min_data_in_leaf_ || sum_hessians_in_smaller_leaf <= min_sum_hessian_in_leaf_) &&
(num_data_in_larger_leaf <= min_data_in_leaf_ || sum_hessians_in_larger_leaf <= min_sum_hessian_in_leaf_)) {
return;
}
LaunchConstructHistogramKernel(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
SynchronizeCUDADevice(__FILE__, __LINE__);
}
void CUDAHistogramConstructor::SubtractHistogramForLeaf(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const CUDALeafSplitsStruct* cuda_larger_leaf_splits,
const bool use_quantized_grad,
const uint8_t parent_num_bits_in_histogram_bins,
const uint8_t smaller_num_bits_in_histogram_bins,
const uint8_t larger_num_bits_in_histogram_bins) {
global_timer.Start("CUDAHistogramConstructor::ConstructHistogramForLeaf::LaunchSubtractHistogramKernel");
LaunchSubtractHistogramKernel(cuda_smaller_leaf_splits, cuda_larger_leaf_splits, use_quantized_grad,
parent_num_bits_in_histogram_bins, smaller_num_bits_in_histogram_bins, larger_num_bits_in_histogram_bins);
global_timer.Stop("CUDAHistogramConstructor::ConstructHistogramForLeaf::LaunchSubtractHistogramKernel");
}
@@ -152,33 +160,18 @@ void CUDAHistogramConstructor::ResetTrainingData(const Dataset* train_data, Trai
num_data_ = train_data->num_data();
num_features_ = train_data->num_features();
InitFeatureMetaInfo(train_data, share_states->feature_hist_offsets());
- if (feature_num_bins_.size() > 0) {
-   DeallocateCUDAMemory<uint32_t>(&cuda_feature_num_bins_, __FILE__, __LINE__);
-   DeallocateCUDAMemory<uint32_t>(&cuda_feature_hist_offsets_, __FILE__, __LINE__);
-   DeallocateCUDAMemory<uint32_t>(&cuda_feature_most_freq_bins_, __FILE__, __LINE__);
-   DeallocateCUDAMemory<int>(&cuda_need_fix_histogram_features_, __FILE__, __LINE__);
-   DeallocateCUDAMemory<uint32_t>(&cuda_need_fix_histogram_features_num_bin_aligned_, __FILE__, __LINE__);
-   DeallocateCUDAMemory<hist_t>(&cuda_hist_, __FILE__, __LINE__);
- }
- AllocateCUDAMemory<hist_t>(&cuda_hist_, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
- SetCUDAMemory<hist_t>(cuda_hist_, 0, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
- InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_num_bins_,
-   feature_num_bins_.data(), feature_num_bins_.size(), __FILE__, __LINE__);
- InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_hist_offsets_,
-   feature_hist_offsets_.data(), feature_hist_offsets_.size(), __FILE__, __LINE__);
- InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_most_freq_bins_,
-   feature_most_freq_bins_.data(), feature_most_freq_bins_.size(), __FILE__, __LINE__);
- InitCUDAMemoryFromHostMemory<int>(&cuda_need_fix_histogram_features_, need_fix_histogram_features_.data(), need_fix_histogram_features_.size(), __FILE__, __LINE__);
- InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_need_fix_histogram_features_num_bin_aligned_, need_fix_histogram_features_num_bin_aligend_.data(),
-   need_fix_histogram_features_num_bin_aligend_.size(), __FILE__, __LINE__);
cuda_hist_.Resize(static_cast<size_t>(num_total_bin_ * 2 * num_leaves_));
cuda_hist_.SetValue(0);
cuda_feature_num_bins_.InitFromHostVector(feature_num_bins_);
cuda_feature_hist_offsets_.InitFromHostVector(feature_hist_offsets_);
cuda_feature_most_freq_bins_.InitFromHostVector(feature_most_freq_bins_);
cuda_row_data_.reset(new CUDARowData(train_data, share_states, gpu_device_id_, gpu_use_dp_));
cuda_row_data_->Init(train_data, share_states);
cuda_need_fix_histogram_features_.InitFromHostVector(need_fix_histogram_features_);
cuda_need_fix_histogram_features_num_bin_aligned_.InitFromHostVector(need_fix_histogram_features_num_bin_aligend_);
}
void CUDAHistogramConstructor::ResetConfig(const Config* config) {
@@ -186,9 +179,8 @@ void CUDAHistogramConstructor::ResetConfig(const Config* config) {
num_leaves_ = config->num_leaves;
min_data_in_leaf_ = config->min_data_in_leaf;
min_sum_hessian_in_leaf_ = config->min_sum_hessian_in_leaf;
- DeallocateCUDAMemory<hist_t>(&cuda_hist_, __FILE__, __LINE__);
- AllocateCUDAMemory<hist_t>(&cuda_hist_, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
- SetCUDAMemory<hist_t>(cuda_hist_, 0, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
cuda_hist_.Resize(static_cast<size_t>(num_total_bin_ * 2 * num_leaves_));
cuda_hist_.SetValue(0);
}
}  // namespace LightGBM
...
@@ -9,6 +9,7 @@
#ifdef USE_CUDA
#include <LightGBM/cuda/cuda_row_data.hpp>
#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/feature_group.h>
#include <LightGBM/tree.h>
@@ -37,7 +38,9 @@ class CUDAHistogramConstructor {
const int min_data_in_leaf,
const double min_sum_hessian_in_leaf,
const int gpu_device_id,
const bool gpu_use_dp,
const bool use_discretized_grad,
const int grad_discretized_bins);
~CUDAHistogramConstructor();
@@ -49,7 +52,16 @@ class CUDAHistogramConstructor {
const data_size_t num_data_in_smaller_leaf,
const data_size_t num_data_in_larger_leaf,
const double sum_hessians_in_smaller_leaf,
const double sum_hessians_in_larger_leaf,
const uint8_t num_bits_in_histogram_bins);
void SubtractHistogramForLeaf(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const CUDALeafSplitsStruct* cuda_larger_leaf_splits,
const bool use_discretized_grad,
const uint8_t parent_num_bits_in_histogram_bins,
const uint8_t smaller_num_bits_in_histogram_bins,
const uint8_t larger_num_bits_in_histogram_bins);
void ResetTrainingData(const Dataset* train_data, TrainingShareStates* share_states);
@@ -57,9 +69,9 @@ class CUDAHistogramConstructor {
void BeforeTrain(const score_t* gradients, const score_t* hessians);
const hist_t* cuda_hist() const { return cuda_hist_.RawData(); }
hist_t* cuda_hist_pointer() { return cuda_hist_.RawData(); }
private:
void InitFeatureMetaInfo(const Dataset* train_data, const std::vector<uint32_t>& feature_hist_offsets);
@@ -74,30 +86,39 @@ class CUDAHistogramConstructor {
template <typename HIST_TYPE, size_t SHARED_HIST_SIZE>
void LaunchConstructHistogramKernelInner(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const data_size_t num_data_in_smaller_leaf,
const uint8_t num_bits_in_histogram_bins);
template <typename HIST_TYPE, size_t SHARED_HIST_SIZE, typename BIN_TYPE>
void LaunchConstructHistogramKernelInner0(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const data_size_t num_data_in_smaller_leaf,
const uint8_t num_bits_in_histogram_bins);
template <typename HIST_TYPE, size_t SHARED_HIST_SIZE, typename BIN_TYPE, typename PTR_TYPE>
void LaunchConstructHistogramKernelInner1(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const data_size_t num_data_in_smaller_leaf,
const uint8_t num_bits_in_histogram_bins);
template <typename HIST_TYPE, size_t SHARED_HIST_SIZE, typename BIN_TYPE, typename PTR_TYPE, bool USE_GLOBAL_MEM_BUFFER>
void LaunchConstructHistogramKernelInner2(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const data_size_t num_data_in_smaller_leaf,
const uint8_t num_bits_in_histogram_bins);
void LaunchConstructHistogramKernel(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const data_size_t num_data_in_smaller_leaf,
const uint8_t num_bits_in_histogram_bins);
void LaunchSubtractHistogramKernel(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const CUDALeafSplitsStruct* cuda_larger_leaf_splits,
const bool use_discretized_grad,
const uint8_t parent_num_bits_in_histogram_bins,
const uint8_t smaller_num_bits_in_histogram_bins,
const uint8_t larger_num_bits_in_histogram_bins);
// Host memory
@@ -136,19 +157,21 @@ class CUDAHistogramConstructor {
/*! \brief CUDA row wise data */
std::unique_ptr<CUDARowData> cuda_row_data_;
/*! \brief number of bins per feature */
CUDAVector<uint32_t> cuda_feature_num_bins_;
/*! \brief offsets in histogram of all features */
CUDAVector<uint32_t> cuda_feature_hist_offsets_;
/*! \brief most frequent bins in each feature */
CUDAVector<uint32_t> cuda_feature_most_freq_bins_;
/*! \brief CUDA histograms */
CUDAVector<hist_t> cuda_hist_;
/*! \brief CUDA histograms buffer for each block */
CUDAVector<float> cuda_hist_buffer_;
/*! \brief indices of feature whose histograms need to be fixed */
CUDAVector<int> cuda_need_fix_histogram_features_;
/*! \brief aligned number of bins of the features whose histograms need to be fixed */
CUDAVector<uint32_t> cuda_need_fix_histogram_features_num_bin_aligned_;
/*! \brief histogram buffer used in histogram subtraction with different number of bits for histogram bins */
CUDAVector<hist_t> hist_buffer_for_num_bit_change_;
// CUDA memory, held by other object
@@ -161,6 +184,10 @@ class CUDAHistogramConstructor {
const int gpu_device_id_;
/*! \brief use double precision histogram per block */
const bool gpu_use_dp_;
/*! \brief whether to use quantized gradients */
const bool use_quantized_grad_;
/*! \brief the number of bins to quantized gradients */
const int num_grad_quant_bins_;
};
}  // namespace LightGBM
...
@@ -11,27 +11,22 @@
namespace LightGBM {
CUDALeafSplits::CUDALeafSplits(const data_size_t num_data):
num_data_(num_data) {}
- cuda_struct_ = nullptr;
- cuda_sum_of_gradients_buffer_ = nullptr;
- cuda_sum_of_hessians_buffer_ = nullptr;
- }
CUDALeafSplits::~CUDALeafSplits() {}
- DeallocateCUDAMemory<CUDALeafSplitsStruct>(&cuda_struct_, __FILE__, __LINE__);
- DeallocateCUDAMemory<double>(&cuda_sum_of_gradients_buffer_, __FILE__, __LINE__);
- DeallocateCUDAMemory<double>(&cuda_sum_of_hessians_buffer_, __FILE__, __LINE__);
- }
void CUDALeafSplits::Init(const bool use_quantized_grad) {
num_blocks_init_from_gradients_ = (num_data_ + NUM_THRADS_PER_BLOCK_LEAF_SPLITS - 1) / NUM_THRADS_PER_BLOCK_LEAF_SPLITS;
// allocate more memory for sum reduction in CUDA
// only the first element records the final sum
cuda_sum_of_gradients_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
cuda_sum_of_hessians_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
if (use_quantized_grad) {
cuda_sum_of_gradients_hessians_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
}
cuda_struct_.Resize(1);
}
void CUDALeafSplits::InitValues() {
@@ -46,24 +41,33 @@ void CUDALeafSplits::InitValues(
const data_size_t num_used_indices, hist_t* cuda_hist_in_leaf, double* root_sum_hessians) {
cuda_gradients_ = cuda_gradients;
cuda_hessians_ = cuda_hessians;
cuda_sum_of_gradients_buffer_.SetValue(0);
cuda_sum_of_hessians_buffer_.SetValue(0);
LaunchInitValuesKernal(lambda_l1, lambda_l2, cuda_bagging_data_indices, cuda_data_indices_in_leaf, num_used_indices, cuda_hist_in_leaf);
CopyFromCUDADeviceToHost<double>(root_sum_hessians, cuda_sum_of_hessians_buffer_.RawData(), 1, __FILE__, __LINE__);
SynchronizeCUDADevice(__FILE__, __LINE__);
}
void CUDALeafSplits::InitValues(
const double lambda_l1, const double lambda_l2,
const int16_t* cuda_gradients_and_hessians,
const data_size_t* cuda_bagging_data_indices,
const data_size_t* cuda_data_indices_in_leaf, const data_size_t num_used_indices,
hist_t* cuda_hist_in_leaf, double* root_sum_hessians,
const score_t* grad_scale, const score_t* hess_scale) {
cuda_gradients_ = reinterpret_cast<const score_t*>(cuda_gradients_and_hessians);
cuda_hessians_ = nullptr;
LaunchInitValuesKernal(lambda_l1, lambda_l2, cuda_bagging_data_indices, cuda_data_indices_in_leaf, num_used_indices, cuda_hist_in_leaf, grad_scale, hess_scale);
CopyFromCUDADeviceToHost<double>(root_sum_hessians, cuda_sum_of_hessians_buffer_.RawData(), 1, __FILE__, __LINE__);
SynchronizeCUDADevice(__FILE__, __LINE__);
}
void CUDALeafSplits::Resize(const data_size_t num_data) {
- if (num_data > num_data_) {
-   DeallocateCUDAMemory<double>(&cuda_sum_of_gradients_buffer_, __FILE__, __LINE__);
-   DeallocateCUDAMemory<double>(&cuda_sum_of_hessians_buffer_, __FILE__, __LINE__);
-   num_blocks_init_from_gradients_ = (num_data + NUM_THRADS_PER_BLOCK_LEAF_SPLITS - 1) / NUM_THRADS_PER_BLOCK_LEAF_SPLITS;
-   AllocateCUDAMemory<double>(&cuda_sum_of_gradients_buffer_, num_blocks_init_from_gradients_, __FILE__, __LINE__);
-   AllocateCUDAMemory<double>(&cuda_sum_of_hessians_buffer_, num_blocks_init_from_gradients_, __FILE__, __LINE__);
- } else {
-   num_blocks_init_from_gradients_ = (num_data + NUM_THRADS_PER_BLOCK_LEAF_SPLITS - 1) / NUM_THRADS_PER_BLOCK_LEAF_SPLITS;
- }
num_data_ = num_data;
num_blocks_init_from_gradients_ = (num_data + NUM_THRADS_PER_BLOCK_LEAF_SPLITS - 1) / NUM_THRADS_PER_BLOCK_LEAF_SPLITS;
cuda_sum_of_gradients_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
cuda_sum_of_hessians_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
cuda_sum_of_gradients_hessians_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
}
}  // namespace LightGBM
...
@@ -81,6 +81,90 @@ __global__ void CUDAInitValuesKernel2(
}
}
template <bool USE_INDICES>
__global__ void CUDAInitValuesKernel3(const int16_t* cuda_gradients_and_hessians,
const data_size_t num_data, const data_size_t* cuda_bagging_data_indices,
double* cuda_sum_of_gradients, double* cuda_sum_of_hessians, int64_t* cuda_sum_of_hessians_hessians,
const score_t* grad_scale_pointer, const score_t* hess_scale_pointer) {
const score_t grad_scale = *grad_scale_pointer;
const score_t hess_scale = *hess_scale_pointer;
__shared__ int64_t shared_mem_buffer[32];
const data_size_t data_index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
int64_t int_gradient = 0;
int64_t int_hessian = 0;
if (data_index < num_data) {
int_gradient = USE_INDICES ? cuda_gradients_and_hessians[2 * cuda_bagging_data_indices[data_index] + 1] :
cuda_gradients_and_hessians[2 * data_index + 1];
int_hessian = USE_INDICES ? cuda_gradients_and_hessians[2 * cuda_bagging_data_indices[data_index]] :
cuda_gradients_and_hessians[2 * data_index];
}
const int64_t block_sum_gradient = ShuffleReduceSum<int64_t>(int_gradient, shared_mem_buffer, blockDim.x);
__syncthreads();
const int64_t block_sum_hessian = ShuffleReduceSum<int64_t>(int_hessian, shared_mem_buffer, blockDim.x);
if (threadIdx.x == 0) {
cuda_sum_of_gradients[blockIdx.x] = block_sum_gradient * grad_scale;
cuda_sum_of_hessians[blockIdx.x] = block_sum_hessian * hess_scale;
cuda_sum_of_hessians_hessians[blockIdx.x] = ((block_sum_gradient << 32) | block_sum_hessian);
}
}
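
CUDAInitValuesKernel3 packs each block's integer gradient and hessian sums into one int64 (gradient in the high 32 bits, hessian in the low 32 bits) so that later reductions can carry the pair as a single value. A small sketch of that packing convention (illustrative only; it assumes the hessian sum is non-negative and fits in 32 bits, which holds because the quantized hessians are non-negative int16 values):

// --- illustrative sketch, not part of this commit ---
#include <cstdint>

// Pack a (gradient, hessian) pair of block sums into a single int64.
int64_t PackGradHess(int64_t grad_sum, int64_t hess_sum) {
  return (grad_sum << 32) | hess_sum;  // hess_sum must be non-negative and < 2^32
}

// Recover the two halves; the arithmetic shift restores the gradient's sign.
int64_t UnpackGrad(int64_t packed) { return packed >> 32; }
int64_t UnpackHess(int64_t packed) { return packed & 0xffffffffLL; }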
__global__ void CUDAInitValuesKernel4(
const double lambda_l1,
const double lambda_l2,
const int num_blocks_to_reduce,
double* cuda_sum_of_gradients,
double* cuda_sum_of_hessians,
int64_t* cuda_sum_of_gradients_hessians,
const data_size_t num_data,
const data_size_t* cuda_data_indices_in_leaf,
hist_t* cuda_hist_in_leaf,
CUDALeafSplitsStruct* cuda_struct) {
__shared__ double shared_mem_buffer[32];
double thread_sum_of_gradients = 0.0f;
double thread_sum_of_hessians = 0.0f;
int64_t thread_sum_of_gradients_hessians = 0;
for (int block_index = static_cast<int>(threadIdx.x); block_index < num_blocks_to_reduce; block_index += static_cast<int>(blockDim.x)) {
thread_sum_of_gradients += cuda_sum_of_gradients[block_index];
thread_sum_of_hessians += cuda_sum_of_hessians[block_index];
thread_sum_of_gradients_hessians += cuda_sum_of_gradients_hessians[block_index];
}
const double sum_of_gradients = ShuffleReduceSum<double>(thread_sum_of_gradients, shared_mem_buffer, blockDim.x);
__syncthreads();
const double sum_of_hessians = ShuffleReduceSum<double>(thread_sum_of_hessians, shared_mem_buffer, blockDim.x);
__syncthreads();
const int64_t sum_of_gradients_hessians = ShuffleReduceSum<int64_t>(
thread_sum_of_gradients_hessians,
reinterpret_cast<int64_t*>(shared_mem_buffer),
blockDim.x);
if (threadIdx.x == 0) {
cuda_sum_of_hessians[0] = sum_of_hessians;
cuda_struct->leaf_index = 0;
cuda_struct->sum_of_gradients = sum_of_gradients;
cuda_struct->sum_of_hessians = sum_of_hessians;
cuda_struct->sum_of_gradients_hessians = sum_of_gradients_hessians;
cuda_struct->num_data_in_leaf = num_data;
const bool use_l1 = lambda_l1 > 0.0f;
if (!use_l1) {
// no smoothing on root node
cuda_struct->gain = CUDALeafSplits::GetLeafGain<false, false>(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f);
} else {
// no smoothing on root node
cuda_struct->gain = CUDALeafSplits::GetLeafGain<true, false>(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f);
}
if (!use_l1) {
// no smoothing on root node
cuda_struct->leaf_value =
CUDALeafSplits::CalculateSplittedLeafOutput<false, false>(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f);
} else {
// no smoothing on root node
cuda_struct->leaf_value =
CUDALeafSplits::CalculateSplittedLeafOutput<true, false>(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f);
}
cuda_struct->data_indices_in_leaf = cuda_data_indices_in_leaf;
cuda_struct->hist_in_leaf = cuda_hist_in_leaf;
}
}
__global__ void InitValuesEmptyKernel(CUDALeafSplitsStruct* cuda_struct) {
cuda_struct->leaf_index = -1;
cuda_struct->sum_of_gradients = 0.0f;
@@ -93,7 +177,7 @@ __global__ void InitValuesEmptyKernel(CUDALeafSplitsStruct* cuda_struct) {
}
void CUDALeafSplits::LaunchInitValuesEmptyKernel() {
InitValuesEmptyKernel<<<1, 1>>>(cuda_struct_.RawData());
}
void CUDALeafSplits::LaunchInitValuesKernal(
@@ -104,23 +188,55 @@ void CUDALeafSplits::LaunchInitValuesKernal(
hist_t* cuda_hist_in_leaf) {
if (cuda_bagging_data_indices == nullptr) {
CUDAInitValuesKernel1<false><<<num_blocks_init_from_gradients_, NUM_THRADS_PER_BLOCK_LEAF_SPLITS>>>(
cuda_gradients_, cuda_hessians_, num_used_indices, nullptr, cuda_sum_of_gradients_buffer_.RawData(),
cuda_sum_of_hessians_buffer_.RawData());
} else {
CUDAInitValuesKernel1<true><<<num_blocks_init_from_gradients_, NUM_THRADS_PER_BLOCK_LEAF_SPLITS>>>(
cuda_gradients_, cuda_hessians_, num_used_indices, cuda_bagging_data_indices, cuda_sum_of_gradients_buffer_.RawData(),
cuda_sum_of_hessians_buffer_.RawData());
}
SynchronizeCUDADevice(__FILE__, __LINE__);
CUDAInitValuesKernel2<<<1, NUM_THRADS_PER_BLOCK_LEAF_SPLITS>>>(
lambda_l1, lambda_l2,
num_blocks_init_from_gradients_,
cuda_sum_of_gradients_buffer_.RawData(),
cuda_sum_of_hessians_buffer_.RawData(),
num_used_indices,
cuda_data_indices_in_leaf,
cuda_hist_in_leaf,
cuda_struct_.RawData());
SynchronizeCUDADevice(__FILE__, __LINE__);
}
void CUDALeafSplits::LaunchInitValuesKernal(
const double lambda_l1, const double lambda_l2,
const data_size_t* cuda_bagging_data_indices,
const data_size_t* cuda_data_indices_in_leaf,
const data_size_t num_used_indices,
hist_t* cuda_hist_in_leaf,
const score_t* grad_scale,
const score_t* hess_scale) {
if (cuda_bagging_data_indices == nullptr) {
CUDAInitValuesKernel3<false><<<num_blocks_init_from_gradients_, NUM_THRADS_PER_BLOCK_LEAF_SPLITS>>>(
reinterpret_cast<const int16_t*>(cuda_gradients_), num_used_indices, nullptr, cuda_sum_of_gradients_buffer_.RawData(),
cuda_sum_of_hessians_buffer_.RawData(), cuda_sum_of_gradients_hessians_buffer_.RawData(), grad_scale, hess_scale);
} else {
CUDAInitValuesKernel3<true><<<num_blocks_init_from_gradients_, NUM_THRADS_PER_BLOCK_LEAF_SPLITS>>>(
reinterpret_cast<const int16_t*>(cuda_gradients_), num_used_indices, cuda_bagging_data_indices, cuda_sum_of_gradients_buffer_.RawData(),
cuda_sum_of_hessians_buffer_.RawData(), cuda_sum_of_gradients_hessians_buffer_.RawData(), grad_scale, hess_scale);
}
SynchronizeCUDADevice(__FILE__, __LINE__);
CUDAInitValuesKernel4<<<1, NUM_THRADS_PER_BLOCK_LEAF_SPLITS>>>(
lambda_l1, lambda_l2,
num_blocks_init_from_gradients_,
cuda_sum_of_gradients_buffer_.RawData(),
cuda_sum_of_hessians_buffer_.RawData(),
cuda_sum_of_gradients_hessians_buffer_.RawData(),
num_used_indices,
cuda_data_indices_in_leaf,
cuda_hist_in_leaf,
cuda_struct_.RawData());
SynchronizeCUDADevice(__FILE__, __LINE__);
}
...
@@ -8,7 +8,7 @@
#ifdef USE_CUDA
#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/bin.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/meta.h>
@@ -23,6 +23,7 @@ struct CUDALeafSplitsStruct {
int leaf_index;
double sum_of_gradients;
double sum_of_hessians;
int64_t sum_of_gradients_hessians;
data_size_t num_data_in_leaf;
double gain;
double leaf_value;
@@ -36,7 +37,7 @@ class CUDALeafSplits {
~CUDALeafSplits();
void Init(const bool use_quantized_grad);
void InitValues(
const double lambda_l1, const double lambda_l2,
@@ -45,11 +46,19 @@ class CUDALeafSplits {
const data_size_t* cuda_data_indices_in_leaf, const data_size_t num_used_indices,
hist_t* cuda_hist_in_leaf, double* root_sum_hessians);
void InitValues(
const double lambda_l1, const double lambda_l2,
const int16_t* cuda_gradients_and_hessians,
const data_size_t* cuda_bagging_data_indices,
const data_size_t* cuda_data_indices_in_leaf, const data_size_t num_used_indices,
hist_t* cuda_hist_in_leaf, double* root_sum_hessians,
const score_t* grad_scale, const score_t* hess_scale);
void InitValues();
const CUDALeafSplitsStruct* GetCUDAStruct() const { return cuda_struct_.RawDataReadOnly(); }
CUDALeafSplitsStruct* GetCUDAStructRef() { return cuda_struct_.RawData(); }
void Resize(const data_size_t num_data);
@@ -140,14 +149,24 @@ class CUDALeafSplits {
const data_size_t num_used_indices,
hist_t* cuda_hist_in_leaf);
void LaunchInitValuesKernal(
const double lambda_l1, const double lambda_l2,
const data_size_t* cuda_bagging_data_indices,
const data_size_t* cuda_data_indices_in_leaf,
const data_size_t num_used_indices,
hist_t* cuda_hist_in_leaf,
const score_t* grad_scale,
const score_t* hess_scale);
// Host memory
data_size_t num_data_;
int num_blocks_init_from_gradients_;
// CUDA memory, held by this object
CUDAVector<CUDALeafSplitsStruct> cuda_struct_;
CUDAVector<double> cuda_sum_of_gradients_buffer_;
CUDAVector<double> cuda_sum_of_hessians_buffer_;
CUDAVector<int64_t> cuda_sum_of_gradients_hessians_buffer_;
// CUDA memory, held by other object
const score_t* cuda_gradients_;
...
@@ -9,7 +9,7 @@
#include "cuda_single_gpu_tree_learner.hpp"
#include <LightGBM/cuda/cuda_tree.hpp>
#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/feature_group.h>
#include <LightGBM/network.h>
#include <LightGBM/objective_function.h>
@@ -39,13 +39,14 @@ void CUDASingleGPUTreeLearner::Init(const Dataset* train_data, bool is_constant_
SetCUDADevice(gpu_device_id_, __FILE__, __LINE__);
cuda_smaller_leaf_splits_.reset(new CUDALeafSplits(num_data_));
cuda_smaller_leaf_splits_->Init(config_->use_quantized_grad);
cuda_larger_leaf_splits_.reset(new CUDALeafSplits(num_data_));
cuda_larger_leaf_splits_->Init(config_->use_quantized_grad);
cuda_histogram_constructor_.reset(new CUDAHistogramConstructor(train_data_, config_->num_leaves, num_threads_,
share_state_->feature_hist_offsets(),
config_->min_data_in_leaf, config_->min_sum_hessian_in_leaf, gpu_device_id_, config_->gpu_use_dp,
config_->use_quantized_grad, config_->num_grad_quant_bins));
cuda_histogram_constructor_->Init(train_data_, share_state_.get());
const auto& feature_hist_offsets = share_state_->feature_hist_offsets();
@@ -73,11 +74,19 @@ void CUDASingleGPUTreeLearner::Init(const Dataset* train_data, bool is_constant_
}
AllocateBitset();
- cuda_leaf_gradient_stat_buffer_ = nullptr;
- cuda_leaf_hessian_stat_buffer_ = nullptr;
leaf_stat_buffer_size_ = 0;
num_cat_threshold_ = 0;
if (config_->use_quantized_grad) {
cuda_leaf_gradient_stat_buffer_.Resize(config_->num_leaves);
cuda_leaf_hessian_stat_buffer_.Resize(config_->num_leaves);
cuda_gradient_discretizer_.reset(new CUDAGradientDiscretizer(
config_->num_grad_quant_bins, config_->num_iterations, config_->seed, is_constant_hessian, config_->stochastic_rounding));
cuda_gradient_discretizer_->Init(num_data_, config_->num_leaves, train_data_->num_features(), train_data_);
} else {
cuda_gradient_discretizer_.reset(nullptr);
}
#ifdef DEBUG
host_gradients_.resize(num_data_, 0.0f);
host_hessians_.resize(num_data_, 0.0f);
...@@ -101,19 +110,37 @@ void CUDASingleGPUTreeLearner::BeforeTrain() {
const data_size_t* leaf_splits_init_indices =
cuda_data_partition_->use_bagging() ? cuda_data_partition_->cuda_data_indices() : nullptr;
cuda_data_partition_->BeforeTrain();
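// With quantized training, gradients and hessians are discretized first; the packed low-bit gradient/hessian pairs then drive histogram construction and root leaf initialization.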
if (config_->use_quantized_grad) {
cuda_gradient_discretizer_->DiscretizeGradients(num_data_, gradients_, hessians_);
cuda_histogram_constructor_->BeforeTrain(
reinterpret_cast<const score_t*>(cuda_gradient_discretizer_->discretized_gradients_and_hessians()), nullptr);
cuda_smaller_leaf_splits_->InitValues(
config_->lambda_l1,
config_->lambda_l2,
reinterpret_cast<const int16_t*>(cuda_gradient_discretizer_->discretized_gradients_and_hessians()),
leaf_splits_init_indices,
cuda_data_partition_->cuda_data_indices(),
root_num_data,
cuda_histogram_constructor_->cuda_hist_pointer(),
&leaf_sum_hessians_[0],
cuda_gradient_discretizer_->grad_scale_ptr(),
cuda_gradient_discretizer_->hess_scale_ptr());
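// Record the histogram bin bit width for the root leaf (leaf 0, no sibling).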
cuda_gradient_discretizer_->SetNumBitsInHistogramBin<false>(0, -1, root_num_data, 0);
} else {
cuda_histogram_constructor_->BeforeTrain(gradients_, hessians_);
cuda_smaller_leaf_splits_->InitValues(
config_->lambda_l1,
config_->lambda_l2,
gradients_,
hessians_,
leaf_splits_init_indices,
cuda_data_partition_->cuda_data_indices(),
root_num_data,
cuda_histogram_constructor_->cuda_hist_pointer(),
&leaf_sum_hessians_[0]);
}
leaf_num_data_[0] = root_num_data;
cuda_larger_leaf_splits_->InitValues();
col_sampler_.ResetByTree();
cuda_best_split_finder_->BeforeTrain(col_sampler_.is_feature_used_bytree());
leaf_data_start_[0] = 0;
...@@ -141,24 +168,70 @@ Tree* CUDASingleGPUTreeLearner::Train(const score_t* gradients,
const data_size_t num_data_in_larger_leaf = larger_leaf_index_ < 0 ? 0 : leaf_num_data_[larger_leaf_index_];
const double sum_hessians_in_smaller_leaf = leaf_sum_hessians_[smaller_leaf_index_];
const double sum_hessians_in_larger_leaf = larger_leaf_index_ < 0 ? 0 : leaf_sum_hessians_[larger_leaf_index_];
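// Histogram bin bit width for the smaller leaf; 0 when gradients are not quantized.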
const uint8_t num_bits_in_histogram_bins = config_->use_quantized_grad ? cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(smaller_leaf_index_) : 0;
cuda_histogram_constructor_->ConstructHistogramForLeaf(
cuda_smaller_leaf_splits_->GetCUDAStruct(),
cuda_larger_leaf_splits_->GetCUDAStruct(),
num_data_in_smaller_leaf,
num_data_in_larger_leaf,
sum_hessians_in_smaller_leaf,
sum_hessians_in_larger_leaf,
num_bits_in_histogram_bins);
global_timer.Stop("CUDASingleGPUTreeLearner::ConstructHistogramForLeaf");
global_timer.Start("CUDASingleGPUTreeLearner::FindBestSplitsForLeaf");
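// Look up the integer bit widths of the parent and child histograms so that histogram subtraction runs at the correct precision.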
uint8_t parent_num_bits_bin = 0;
uint8_t smaller_num_bits_bin = 0;
uint8_t larger_num_bits_bin = 0;
if (config_->use_quantized_grad) {
if (larger_leaf_index_ != -1) {
const int parent_leaf_index = std::min(smaller_leaf_index_, larger_leaf_index_);
parent_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInNode<false>(parent_leaf_index);
smaller_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(smaller_leaf_index_);
larger_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(larger_leaf_index_);
} else {
parent_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(0);
smaller_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(0);
larger_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(0);
}
} else {
parent_num_bits_bin = 0;
smaller_num_bits_bin = 0;
larger_num_bits_bin = 0;
}
cuda_histogram_constructor_->SubtractHistogramForLeaf(
cuda_smaller_leaf_splits_->GetCUDAStruct(),
cuda_larger_leaf_splits_->GetCUDAStruct(),
config_->use_quantized_grad,
parent_num_bits_bin,
smaller_num_bits_bin,
larger_num_bits_bin);
SelectFeatureByNode(tree.get());
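// The quantized path also passes the gradient/hessian scales and per-leaf histogram bit widths to the best split finder.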
if (config_->use_quantized_grad) {
const uint8_t smaller_leaf_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(smaller_leaf_index_);
const uint8_t larger_leaf_num_bits_bin = larger_leaf_index_ < 0 ? 32 : cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(larger_leaf_index_);
cuda_best_split_finder_->FindBestSplitsForLeaf(
cuda_smaller_leaf_splits_->GetCUDAStruct(),
cuda_larger_leaf_splits_->GetCUDAStruct(),
smaller_leaf_index_, larger_leaf_index_,
num_data_in_smaller_leaf, num_data_in_larger_leaf,
sum_hessians_in_smaller_leaf, sum_hessians_in_larger_leaf,
cuda_gradient_discretizer_->grad_scale_ptr(),
cuda_gradient_discretizer_->hess_scale_ptr(),
smaller_leaf_num_bits_bin,
larger_leaf_num_bits_bin);
} else {
cuda_best_split_finder_->FindBestSplitsForLeaf(
cuda_smaller_leaf_splits_->GetCUDAStruct(),
cuda_larger_leaf_splits_->GetCUDAStruct(),
smaller_leaf_index_, larger_leaf_index_,
num_data_in_smaller_leaf, num_data_in_larger_leaf,
sum_hessians_in_smaller_leaf, sum_hessians_in_larger_leaf,
nullptr, nullptr, 0, 0);
}
global_timer.Stop("CUDASingleGPUTreeLearner::FindBestSplitsForLeaf");
global_timer.Start("CUDASingleGPUTreeLearner::FindBestFromAllSplits");
const CUDASplitInfo* best_split_info = nullptr;
...@@ -247,9 +320,19 @@ Tree* CUDASingleGPUTreeLearner::Train(const score_t* gradients,
#endif // DEBUG
smaller_leaf_index_ = (leaf_num_data_[best_leaf_index_] < leaf_num_data_[right_leaf_index] ? best_leaf_index_ : right_leaf_index);
larger_leaf_index_ = (smaller_leaf_index_ == best_leaf_index_ ? right_leaf_index : best_leaf_index_);
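// Update histogram bin bit widths for the two leaves produced by this split.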
if (config_->use_quantized_grad) {
cuda_gradient_discretizer_->SetNumBitsInHistogramBin<false>(
best_leaf_index_, right_leaf_index, leaf_num_data_[best_leaf_index_], leaf_num_data_[right_leaf_index]);
}
global_timer.Stop("CUDASingleGPUTreeLearner::Split");
}
SynchronizeCUDADevice(__FILE__, __LINE__);
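// Optionally renew leaf values from the original (non-quantized) gradients once the tree structure is fixed.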
if (config_->use_quantized_grad && config_->quant_train_renew_leaf) {
global_timer.Start("CUDASingleGPUTreeLearner::RenewDiscretizedTreeLeaves");
RenewDiscretizedTreeLeaves(tree.get());
global_timer.Stop("CUDASingleGPUTreeLearner::RenewDiscretizedTreeLeaves");
}
tree->ToHost();
return tree.release();
}
...@@ -357,8 +440,8 @@ void CUDASingleGPUTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFuncti
Tree* CUDASingleGPUTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t* hessians) const {
std::unique_ptr<CUDATree> cuda_tree(new CUDATree(old_tree));
cuda_leaf_gradient_stat_buffer_.SetValue(0);
cuda_leaf_hessian_stat_buffer_.SetValue(0);
ReduceLeafStat(cuda_tree.get(), gradients, hessians, cuda_data_partition_->cuda_data_indices());
cuda_tree->SyncLeafOutputFromCUDAToHost();
return cuda_tree.release();
...@@ -373,13 +456,9 @@ Tree* CUDASingleGPUTreeLearner::FitByExistingTree(const Tree* old_tree, const st
const int num_block = (refit_num_data_ + CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE - 1) / CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE;
buffer_size *= static_cast<data_size_t>(num_block + 1);
}
if (static_cast<size_t>(buffer_size) > cuda_leaf_gradient_stat_buffer_.Size()) {
cuda_leaf_gradient_stat_buffer_.Resize(buffer_size);
cuda_leaf_hessian_stat_buffer_.Resize(buffer_size);
}
return FitByExistingTree(old_tree, gradients, hessians);
}
...@@ -513,6 +592,15 @@ void CUDASingleGPUTreeLearner::CheckSplitValid(
}
#endif // DEBUG
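// Recompute leaf outputs from per-leaf sums of the original gradients and hessians accumulated on the GPU.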
void CUDASingleGPUTreeLearner::RenewDiscretizedTreeLeaves(CUDATree* cuda_tree) {
cuda_data_partition_->ReduceLeafGradStat(
gradients_, hessians_, cuda_tree,
cuda_leaf_gradient_stat_buffer_.RawData(),
cuda_leaf_hessian_stat_buffer_.RawData());
LaunchCalcLeafValuesGivenGradStat(cuda_tree, cuda_data_partition_->cuda_data_indices());
SynchronizeCUDADevice(__FILE__, __LINE__);
}
} // namespace LightGBM
#endif // USE_CUDA
...@@ -129,18 +129,18 @@ void CUDASingleGPUTreeLearner::LaunchReduceLeafStatKernel(
if (num_leaves <= 2048) {
ReduceLeafStatKernel_SharedMemory<<<num_block, CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE, 2 * num_leaves * sizeof(double)>>>(
gradients, hessians, num_leaves, num_data, cuda_data_partition_->cuda_data_index_to_leaf_index(),
cuda_leaf_gradient_stat_buffer_.RawData(), cuda_leaf_hessian_stat_buffer_.RawData());
} else {
ReduceLeafStatKernel_GlobalMemory<<<num_block, CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE>>>(
gradients, hessians, num_leaves, num_data, cuda_data_partition_->cuda_data_index_to_leaf_index(),
cuda_leaf_gradient_stat_buffer_.RawData(), cuda_leaf_hessian_stat_buffer_.RawData());
}
const bool use_l1 = config_->lambda_l1 > 0.0f;
const bool use_smoothing = config_->path_smooth > 0.0f;
num_block = (num_leaves + CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE - 1) / CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE;
#define CalcRefitLeafOutputKernel_ARGS \
num_leaves, cuda_leaf_gradient_stat_buffer_.RawData(), cuda_leaf_hessian_stat_buffer_.RawData(), num_data_in_leaf, \
leaf_parent, left_child, right_child, \
config_->lambda_l1, config_->lambda_l2, config_->path_smooth, \
shrinkage_rate, config_->refit_decay_rate, cuda_leaf_value
...@@ -162,6 +162,7 @@ void CUDASingleGPUTreeLearner::LaunchReduceLeafStatKernel(
<<<num_block, CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE>>>(CalcRefitLeafOutputKernel_ARGS);
}
}
#undef CalcRefitLeafOutputKernel_ARGS
}
template <typename T, bool IS_INNER>
...@@ -256,6 +257,37 @@ void CUDASingleGPUTreeLearner::LaunchConstructBitsetForCategoricalSplitKernel(
CUDAConstructBitset<int, false>(best_split_info, num_cat_threshold_, cuda_bitset_, cuda_bitset_len_);
}
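// Computes leaf values from the accumulated per-leaf gradient/hessian sums, reusing the refit kernel with a shrinkage rate of 1.0.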
void CUDASingleGPUTreeLearner::LaunchCalcLeafValuesGivenGradStat(
CUDATree* cuda_tree, const data_size_t* num_data_in_leaf) {
#define CalcRefitLeafOutputKernel_ARGS \
cuda_tree->num_leaves(), cuda_leaf_gradient_stat_buffer_.RawData(), cuda_leaf_hessian_stat_buffer_.RawData(), num_data_in_leaf, \
cuda_tree->cuda_leaf_parent(), cuda_tree->cuda_left_child(), cuda_tree->cuda_right_child(), \
config_->lambda_l1, config_->lambda_l2, config_->path_smooth, \
1.0f, config_->refit_decay_rate, cuda_tree->cuda_leaf_value_ref()
const bool use_l1 = config_->lambda_l1 > 0.0f;
const bool use_smoothing = config_->path_smooth > 0.0f;
const int num_block = (cuda_tree->num_leaves() + CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE - 1) / CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE;
if (!use_l1) {
if (!use_smoothing) {
CalcRefitLeafOutputKernel<false, false>
<<<num_block, CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE>>>(CalcRefitLeafOutputKernel_ARGS);
} else {
CalcRefitLeafOutputKernel<false, true>
<<<num_block, CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE>>>(CalcRefitLeafOutputKernel_ARGS);
}
} else {
if (!use_smoothing) {
CalcRefitLeafOutputKernel<true, false>
<<<num_block, CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE>>>(CalcRefitLeafOutputKernel_ARGS);
} else {
CalcRefitLeafOutputKernel<true, true>
<<<num_block, CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE>>>(CalcRefitLeafOutputKernel_ARGS);
}
}
#undef CalcRefitLeafOutputKernel_ARGS
}
} // namespace LightGBM
#endif // USE_CUDA
...@@ -16,6 +16,7 @@
#include "cuda_data_partition.hpp"
#include "cuda_best_split_finder.hpp"
#include "cuda_gradient_discretizer.hpp"
#include "../serial_tree_learner.h" #include "../serial_tree_learner.h"
namespace LightGBM { namespace LightGBM {
...@@ -74,6 +75,10 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner {
const double sum_left_gradients, const double sum_right_gradients);
#endif // DEBUG
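// helpers for renewing leaf values after quantized training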
void RenewDiscretizedTreeLeaves(CUDATree* cuda_tree);
void LaunchCalcLeafValuesGivenGradStat(CUDATree* cuda_tree, const data_size_t* num_data_in_leaf);
// GPU device ID
int gpu_device_id_;
// number of threads on CPU
...@@ -90,6 +95,8 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner {
std::unique_ptr<CUDAHistogramConstructor> cuda_histogram_constructor_;
// for best split information finding, given the histograms
std::unique_ptr<CUDABestSplitFinder> cuda_best_split_finder_;
// gradient discretizer for quantized training
std::unique_ptr<CUDAGradientDiscretizer> cuda_gradient_discretizer_;
std::vector<int> leaf_best_split_feature_;
std::vector<uint32_t> leaf_best_split_threshold_;
...@@ -108,8 +115,8 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner {
std::vector<int> categorical_bin_to_value_;
std::vector<int> categorical_bin_offsets_;
mutable CUDAVector<double> cuda_leaf_gradient_stat_buffer_;
mutable CUDAVector<double> cuda_leaf_hessian_stat_buffer_;
mutable data_size_t leaf_stat_buffer_size_;
mutable data_size_t refit_num_data_;
uint32_t* cuda_bitset_;
...