Unverified Commit 45c53f78 authored by shiyu1994's avatar shiyu1994 Committed by GitHub
Browse files

[CUDA] Add Huber regression objective for cuda_exp (#5462)

* add huber regression for cuda_exp

* renew tree output on GPU

add test cases for regression objectives

* remove useless changes

* add white space

* fix test_regression
parent 7d1276ad
...@@ -92,7 +92,7 @@ class TreeLearner { ...@@ -92,7 +92,7 @@ class TreeLearner {
virtual void AddPredictionToScore(const Tree* tree, double* out_score) const = 0; virtual void AddPredictionToScore(const Tree* tree, double* out_score) const = 0;
virtual void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter, virtual void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter,
data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const = 0; data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt, const double* train_score) const = 0;
TreeLearner() = default; TreeLearner() = default;
/*! \brief Disable copy */ /*! \brief Disable copy */
......
...@@ -500,7 +500,7 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) { ...@@ -500,7 +500,7 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
auto score_ptr = train_score_updater_->score() + offset; auto score_ptr = train_score_updater_->score() + offset;
auto residual_getter = [score_ptr](const label_t* label, int i) {return static_cast<double>(label[i]) - score_ptr[i]; }; auto residual_getter = [score_ptr](const label_t* label, int i) {return static_cast<double>(label[i]) - score_ptr[i]; };
tree_learner_->RenewTreeOutput(new_tree.get(), objective_function_, residual_getter, tree_learner_->RenewTreeOutput(new_tree.get(), objective_function_, residual_getter,
num_data_, bag_data_indices_.data(), bag_data_cnt_); num_data_, bag_data_indices_.data(), bag_data_cnt_, train_score_updater_->score());
// shrinkage by learning rate // shrinkage by learning rate
new_tree->Shrinkage(shrinkage_rate_); new_tree->Shrinkage(shrinkage_rate_);
// update score // update score
......
...@@ -132,7 +132,7 @@ class RF : public GBDT { ...@@ -132,7 +132,7 @@ class RF : public GBDT {
double pred = init_scores_[cur_tree_id]; double pred = init_scores_[cur_tree_id];
auto residual_getter = [pred](const label_t* label, int i) {return static_cast<double>(label[i]) - pred; }; auto residual_getter = [pred](const label_t* label, int i) {return static_cast<double>(label[i]) - pred; };
tree_learner_->RenewTreeOutput(new_tree.get(), objective_function_, residual_getter, tree_learner_->RenewTreeOutput(new_tree.get(), objective_function_, residual_getter,
num_data_, bag_data_indices_.data(), bag_data_cnt_); num_data_, bag_data_indices_.data(), bag_data_cnt_, train_score_updater_->score());
if (std::fabs(init_scores_[cur_tree_id]) > kEpsilon) { if (std::fabs(init_scores_[cur_tree_id]) > kEpsilon) {
new_tree->AddBias(init_scores_[cur_tree_id]); new_tree->AddBias(init_scores_[cur_tree_id]);
} }
......
...@@ -81,6 +81,20 @@ void CUDARegressionL1loss::RenewTreeOutputCUDA( ...@@ -81,6 +81,20 @@ void CUDARegressionL1loss::RenewTreeOutputCUDA(
global_timer.Stop("CUDARegressionL1loss::LaunchRenewTreeOutputCUDAKernel"); global_timer.Stop("CUDARegressionL1loss::LaunchRenewTreeOutputCUDAKernel");
} }
// Training-time constructor: reuses the L2 objective setup and records the
// Huber transition point alpha (from config.alpha), which separates the
// quadratic and linear regions of the loss.
CUDARegressionHuberLoss::CUDARegressionHuberLoss(const Config& config):
CUDARegressionL2loss(config), alpha_(config.alpha) {
// The sqrt label transform is incompatible with the Huber objective, so it is
// forcibly disabled with a warning rather than silently ignored.
if (sqrt_) {
Log::Warning("Cannot use sqrt transform in %s Regression, will auto disable it", GetName());
sqrt_ = false;
}
}
// Model-loading constructor: alpha_ keeps its in-class default here.
// NOTE(review): presumably this path never recomputes gradients, so the
// default alpha_ is never used — confirm against the model-loading callers.
CUDARegressionHuberLoss::CUDARegressionHuberLoss(const std::vector<std::string>& strs):
CUDARegressionL2loss(strs) {}
// No resources beyond the base class; defined out of line to match siblings.
CUDARegressionHuberLoss::~CUDARegressionHuberLoss() {}
} // namespace LightGBM } // namespace LightGBM
#endif // USE_CUDA_EXP #endif // USE_CUDA_EXP
...@@ -189,6 +189,43 @@ void CUDARegressionL1loss::LaunchRenewTreeOutputCUDAKernel( ...@@ -189,6 +189,43 @@ void CUDARegressionL1loss::LaunchRenewTreeOutputCUDAKernel(
} }
// Computes per-datapoint Huber gradients and hessians.
// Inside the quadratic region (|score - label| <= alpha) the gradient is the
// raw residual; outside it the gradient is clipped to +/- alpha. The hessian
// is the constant 1 (scaled by the sample weight when USE_WEIGHT is set).
// Launch layout: 1-D grid, one thread per data point; threads whose index is
// past num_data return immediately.
template <bool USE_WEIGHT>
__global__ void GetGradientsKernel_Huber(const double* cuda_scores, const label_t* cuda_labels, const label_t* cuda_weights, const data_size_t num_data,
  const double alpha, score_t* cuda_out_gradients, score_t* cuda_out_hessians) {
  const data_size_t data_index = static_cast<data_size_t>(blockDim.x * blockIdx.x + threadIdx.x);
  if (data_index >= num_data) {
    return;
  }
  const double diff = cuda_scores[data_index] - static_cast<double>(cuda_labels[data_index]);
  // sign(diff) in {-1, 0, +1}, computed branchlessly.
  const score_t sign = static_cast<score_t>((diff > 0.0f) - (diff < 0.0f));
  // Quadratic region keeps the raw residual; linear region clips to +/- alpha.
  const score_t grad = (fabs(diff) <= alpha) ?
      static_cast<score_t>(diff) :
      static_cast<score_t>(sign * alpha);
  if (USE_WEIGHT) {
    const score_t weight = static_cast<score_t>(cuda_weights[data_index]);
    cuda_out_gradients[data_index] = grad * weight;
    cuda_out_hessians[data_index] = weight;
  } else {
    cuda_out_gradients[data_index] = grad;
    cuda_out_hessians[data_index] = 1.0f;
  }
}
// Launches the Huber gradient kernel over all training rows, selecting the
// weighted or unweighted template instantiation at runtime.
void CUDARegressionHuberLoss::LaunchGetGradientsKernel(const double* score, score_t* gradients, score_t* hessians) const {
  // Ceil-divide so the final partial block still covers the tail of the data.
  const int num_blocks = (num_data_ + GET_GRADIENTS_BLOCK_SIZE_REGRESSION - 1) / GET_GRADIENTS_BLOCK_SIZE_REGRESSION;
  if (cuda_weights_ != nullptr) {
    GetGradientsKernel_Huber<true><<<num_blocks, GET_GRADIENTS_BLOCK_SIZE_REGRESSION>>>(score, cuda_labels_, cuda_weights_, num_data_, alpha_, gradients, hessians);
  } else {
    GetGradientsKernel_Huber<false><<<num_blocks, GET_GRADIENTS_BLOCK_SIZE_REGRESSION>>>(score, cuda_labels_, nullptr, num_data_, alpha_, gradients, hessians);
  }
}
} // namespace LightGBM } // namespace LightGBM
#endif // USE_CUDA_EXP #endif // USE_CUDA_EXP
...@@ -101,6 +101,23 @@ class CUDARegressionL1loss : public CUDARegressionL2loss { ...@@ -101,6 +101,23 @@ class CUDARegressionL1loss : public CUDARegressionL2loss {
}; };
// CUDA implementation of the Huber regression objective.
// Inherits the L2 objective's machinery and overrides only the gradient
// kernel; leaf outputs are refitted after tree growth (IsRenewTreeOutput).
class CUDARegressionHuberLoss : public CUDARegressionL2loss {
 public:
  // Training-time constructor; reads alpha from the config.
  explicit CUDARegressionHuberLoss(const Config& config);

  // Model-loading constructor; alpha_ keeps its in-class default.
  explicit CUDARegressionHuberLoss(const std::vector<std::string>& strs);

  ~CUDARegressionHuberLoss();

  // Huber leaf values are renewed after the tree structure is fixed.
  bool IsRenewTreeOutput() const override { return true; }

 private:
  // Fills gradients/hessians for all rows via GetGradientsKernel_Huber.
  void LaunchGetGradientsKernel(const double* score, score_t* gradients, score_t* hessians) const override;

  // Transition point between the quadratic and linear loss regions; set from
  // config.alpha in the Config constructor, otherwise left at its default.
  // Fix: initialize the double member with a double literal (was 0.0f; the
  // value is identical, but the float literal was an idiom slip).
  const double alpha_ = 0.0;
};
} // namespace LightGBM } // namespace LightGBM
#endif // USE_CUDA_EXP #endif // USE_CUDA_EXP
......
...@@ -300,21 +300,27 @@ void CUDASingleGPUTreeLearner::SetBaggingData(const Dataset* /*subset*/, ...@@ -300,21 +300,27 @@ void CUDASingleGPUTreeLearner::SetBaggingData(const Dataset* /*subset*/,
} }
void CUDASingleGPUTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter, void CUDASingleGPUTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter,
data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const { data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt, const double* train_score) const {
CHECK(tree->is_cuda_tree()); CHECK(tree->is_cuda_tree());
CUDATree* cuda_tree = reinterpret_cast<CUDATree*>(tree); CUDATree* cuda_tree = reinterpret_cast<CUDATree*>(tree);
if (obj != nullptr && obj->IsRenewTreeOutput()) { if (obj != nullptr && obj->IsRenewTreeOutput()) {
CHECK_LE(cuda_tree->num_leaves(), data_partition_->num_leaves()); CHECK_LE(cuda_tree->num_leaves(), data_partition_->num_leaves());
if (boosting_on_cuda_) {
obj->RenewTreeOutputCUDA(train_score, cuda_data_partition_->cuda_data_indices(),
cuda_data_partition_->cuda_leaf_num_data(), cuda_data_partition_->cuda_leaf_data_start(),
cuda_tree->num_leaves(), cuda_tree->cuda_leaf_value_ref());
cuda_tree->SyncLeafOutputFromCUDAToHost();
} else {
const data_size_t* bag_mapper = nullptr; const data_size_t* bag_mapper = nullptr;
if (total_num_data != num_data_) { if (total_num_data != num_data_) {
CHECK_EQ(bag_cnt, num_data_); CHECK_EQ(bag_cnt, num_data_);
bag_mapper = bag_indices; bag_mapper = bag_indices;
} }
std::vector<int> n_nozeroworker_perleaf(tree->num_leaves(), 1); std::vector<int> n_nozeroworker_perleaf(cuda_tree->num_leaves(), 1);
int num_machines = Network::num_machines(); int num_machines = Network::num_machines();
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int i = 0; i < tree->num_leaves(); ++i) { for (int i = 0; i < cuda_tree->num_leaves(); ++i) {
const double output = static_cast<double>(tree->LeafOutput(i)); const double output = static_cast<double>(cuda_tree->LeafOutput(i));
data_size_t cnt_leaf_data = leaf_num_data_[i]; data_size_t cnt_leaf_data = leaf_num_data_[i];
std::vector<data_size_t> index_mapper(cnt_leaf_data, -1); std::vector<data_size_t> index_mapper(cnt_leaf_data, -1);
CopyFromCUDADeviceToHost<data_size_t>(index_mapper.data(), CopyFromCUDADeviceToHost<data_size_t>(index_mapper.data(),
...@@ -322,26 +328,27 @@ void CUDASingleGPUTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFuncti ...@@ -322,26 +328,27 @@ void CUDASingleGPUTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFuncti
static_cast<size_t>(cnt_leaf_data), __FILE__, __LINE__); static_cast<size_t>(cnt_leaf_data), __FILE__, __LINE__);
if (cnt_leaf_data > 0) { if (cnt_leaf_data > 0) {
const double new_output = obj->RenewTreeOutput(output, residual_getter, index_mapper.data(), bag_mapper, cnt_leaf_data); const double new_output = obj->RenewTreeOutput(output, residual_getter, index_mapper.data(), bag_mapper, cnt_leaf_data);
tree->SetLeafOutput(i, new_output); cuda_tree->SetLeafOutput(i, new_output);
} else { } else {
CHECK_GT(num_machines, 1); CHECK_GT(num_machines, 1);
tree->SetLeafOutput(i, 0.0); cuda_tree->SetLeafOutput(i, 0.0);
n_nozeroworker_perleaf[i] = 0; n_nozeroworker_perleaf[i] = 0;
} }
} }
if (num_machines > 1) { if (num_machines > 1) {
std::vector<double> outputs(tree->num_leaves()); std::vector<double> outputs(cuda_tree->num_leaves());
for (int i = 0; i < tree->num_leaves(); ++i) { for (int i = 0; i < cuda_tree->num_leaves(); ++i) {
outputs[i] = static_cast<double>(tree->LeafOutput(i)); outputs[i] = static_cast<double>(cuda_tree->LeafOutput(i));
} }
outputs = Network::GlobalSum(&outputs); outputs = Network::GlobalSum(&outputs);
n_nozeroworker_perleaf = Network::GlobalSum(&n_nozeroworker_perleaf); n_nozeroworker_perleaf = Network::GlobalSum(&n_nozeroworker_perleaf);
for (int i = 0; i < tree->num_leaves(); ++i) { for (int i = 0; i < cuda_tree->num_leaves(); ++i) {
tree->SetLeafOutput(i, outputs[i] / n_nozeroworker_perleaf[i]); cuda_tree->SetLeafOutput(i, outputs[i] / n_nozeroworker_perleaf[i]);
} }
} }
} }
cuda_tree->SyncLeafOutputFromHostToCUDA(); cuda_tree->SyncLeafOutputFromHostToCUDA();
}
} }
Tree* CUDASingleGPUTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t* hessians) const { Tree* CUDASingleGPUTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t* hessians) const {
......
...@@ -40,7 +40,7 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner { ...@@ -40,7 +40,7 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner {
void AddPredictionToScore(const Tree* tree, double* out_score) const override; void AddPredictionToScore(const Tree* tree, double* out_score) const override;
void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter, void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter,
data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const override; data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt, const double* train_score) const override;
void ResetConfig(const Config* config) override; void ResetConfig(const Config* config) override;
......
...@@ -719,7 +719,7 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf, ...@@ -719,7 +719,7 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
} }
void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter, void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter,
data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const { data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt, const double* /*train_score*/) const {
if (obj != nullptr && obj->IsRenewTreeOutput()) { if (obj != nullptr && obj->IsRenewTreeOutput()) {
CHECK_LE(tree->num_leaves(), data_partition_->num_leaves()); CHECK_LE(tree->num_leaves(), data_partition_->num_leaves());
const data_size_t* bag_mapper = nullptr; const data_size_t* bag_mapper = nullptr;
......
...@@ -114,7 +114,7 @@ class SerialTreeLearner: public TreeLearner { ...@@ -114,7 +114,7 @@ class SerialTreeLearner: public TreeLearner {
} }
void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter, void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter,
data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const override; data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt, const double* train_score) const override;
/*! \brief Get output of parent node, used for path smoothing */ /*! \brief Get output of parent node, used for path smoothing */
double GetParentOutput(const Tree* tree, const LeafSplits* leaf_splits) const; double GetParentOutput(const Tree* tree, const LeafSplits* leaf_splits) const;
......
...@@ -111,10 +111,12 @@ def test_rf(): ...@@ -111,10 +111,12 @@ def test_rf():
assert evals_result['valid_0']['binary_logloss'][-1] == pytest.approx(ret) assert evals_result['valid_0']['binary_logloss'][-1] == pytest.approx(ret)
def test_regression(): @pytest.mark.parametrize('objective', ['regression', 'regression_l1', 'huber'])
def test_regression(objective):
X, y = load_boston(return_X_y=True) X, y = load_boston(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
params = { params = {
'objective': objective,
'metric': 'l2', 'metric': 'l2',
'verbose': -1 'verbose': -1
} }
...@@ -129,6 +131,9 @@ def test_regression(): ...@@ -129,6 +131,9 @@ def test_regression():
callbacks=[lgb.record_evaluation(evals_result)] callbacks=[lgb.record_evaluation(evals_result)]
) )
ret = mean_squared_error(y_test, gbm.predict(X_test)) ret = mean_squared_error(y_test, gbm.predict(X_test))
if objective == 'huber':
assert ret < 35
else:
assert ret < 7 assert ret < 7
assert evals_result['valid_0']['l2'][-1] == pytest.approx(ret) assert evals_result['valid_0']['l2'][-1] == pytest.approx(ret)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment