Unverified Commit 3670e476 authored by CharlesAuguste's avatar CharlesAuguste Committed by GitHub
Browse files

Refactoring monotone constraints (linked to #2305) (#2717)



* Move monotone constraints to the monotone_constraints files.

* Add checks for debug mode.

* Refactored FindBestSplitsFromHistograms.

* Add headers.

* fix

* Update data_parallel_tree_learner.cpp

* simplify ComputeBestSplitForFeature

* Fix min / max issue.

* Remove duplicated check.
Co-authored-by: Guolin Ke <guolin.ke@outlook.com>
parent d8a34df9
...@@ -562,6 +562,9 @@ std::vector<double> GBDT::FeatureImportance(int num_iteration, int importance_ty ...@@ -562,6 +562,9 @@ std::vector<double> GBDT::FeatureImportance(int num_iteration, int importance_ty
for (int iter = 0; iter < num_used_model; ++iter) { for (int iter = 0; iter < num_used_model; ++iter) {
for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) { for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
if (models_[iter]->split_gain(split_idx) > 0) { if (models_[iter]->split_gain(split_idx) > 0) {
#ifdef DEBUG
CHECK(models_[iter]->split_feature(split_idx) >= 0);
#endif
feature_importances[models_[iter]->split_feature(split_idx)] += 1.0; feature_importances[models_[iter]->split_feature(split_idx)] += 1.0;
} }
} }
...@@ -570,6 +573,9 @@ std::vector<double> GBDT::FeatureImportance(int num_iteration, int importance_ty ...@@ -570,6 +573,9 @@ std::vector<double> GBDT::FeatureImportance(int num_iteration, int importance_ty
for (int iter = 0; iter < num_used_model; ++iter) { for (int iter = 0; iter < num_used_model; ++iter) {
for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) { for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
if (models_[iter]->split_gain(split_idx) > 0) { if (models_[iter]->split_gain(split_idx) > 0) {
#ifdef DEBUG
CHECK(models_[iter]->split_feature(split_idx) >= 0);
#endif
feature_importances[models_[iter]->split_feature(split_idx)] += models_[iter]->split_gain(split_idx); feature_importances[models_[iter]->split_feature(split_idx)] += models_[iter]->split_gain(split_idx);
} }
} }
......
...@@ -187,67 +187,55 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const ...@@ -187,67 +187,55 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
this->train_data_->FixHistogram(feature_index, this->train_data_->FixHistogram(feature_index,
this->smaller_leaf_splits_->sum_gradients(), this->smaller_leaf_splits_->sum_hessians(), this->smaller_leaf_splits_->sum_gradients(), this->smaller_leaf_splits_->sum_hessians(),
this->smaller_leaf_histogram_array_[feature_index].RawData()); this->smaller_leaf_histogram_array_[feature_index].RawData());
SplitInfo smaller_split;
// find best threshold for smaller child this->ComputeBestSplitForFeature(
this->smaller_leaf_histogram_array_[feature_index].FindBestThreshold( this->smaller_leaf_histogram_array_, feature_index, real_feature_index,
this->smaller_leaf_splits_->sum_gradients(), smaller_node_used_features[feature_index],
this->smaller_leaf_splits_->sum_hessians(), GetGlobalDataCountInLeaf(this->smaller_leaf_splits_->leaf_index()),
GetGlobalDataCountInLeaf(this->smaller_leaf_splits_->LeafIndex()), this->smaller_leaf_splits_.get(),
this->smaller_leaf_splits_->min_constraint(), &smaller_bests_per_thread[tid]);
this->smaller_leaf_splits_->max_constraint(),
&smaller_split);
smaller_split.feature = real_feature_index;
if (smaller_split > smaller_bests_per_thread[tid] && smaller_node_used_features[feature_index]) {
smaller_bests_per_thread[tid] = smaller_split;
}
// only root leaf // only root leaf
if (this->larger_leaf_splits_ == nullptr || this->larger_leaf_splits_->LeafIndex() < 0) continue; if (this->larger_leaf_splits_ == nullptr || this->larger_leaf_splits_->leaf_index() < 0) continue;
// construct histgroms for large leaf, we init larger leaf as the parent, so we can just subtract the smaller leaf's histograms // construct histgroms for large leaf, we init larger leaf as the parent, so we can just subtract the smaller leaf's histograms
this->larger_leaf_histogram_array_[feature_index].Subtract( this->larger_leaf_histogram_array_[feature_index].Subtract(
this->smaller_leaf_histogram_array_[feature_index]); this->smaller_leaf_histogram_array_[feature_index]);
SplitInfo larger_split;
// find best threshold for larger child this->ComputeBestSplitForFeature(
this->larger_leaf_histogram_array_[feature_index].FindBestThreshold( this->larger_leaf_histogram_array_, feature_index, real_feature_index,
this->larger_leaf_splits_->sum_gradients(), larger_node_used_features[feature_index],
this->larger_leaf_splits_->sum_hessians(), GetGlobalDataCountInLeaf(this->larger_leaf_splits_->leaf_index()),
GetGlobalDataCountInLeaf(this->larger_leaf_splits_->LeafIndex()), this->larger_leaf_splits_.get(),
this->larger_leaf_splits_->min_constraint(), &larger_bests_per_thread[tid]);
this->larger_leaf_splits_->max_constraint(),
&larger_split);
larger_split.feature = real_feature_index;
if (larger_split > larger_bests_per_thread[tid] && larger_node_used_features[feature_index]) {
larger_bests_per_thread[tid] = larger_split;
}
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
OMP_THROW_EX(); OMP_THROW_EX();
auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_bests_per_thread); auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_bests_per_thread);
int leaf = this->smaller_leaf_splits_->LeafIndex(); int leaf = this->smaller_leaf_splits_->leaf_index();
this->best_split_per_leaf_[leaf] = smaller_bests_per_thread[smaller_best_idx]; this->best_split_per_leaf_[leaf] = smaller_bests_per_thread[smaller_best_idx];
if (this->larger_leaf_splits_ != nullptr && this->larger_leaf_splits_->LeafIndex() >= 0) { if (this->larger_leaf_splits_ != nullptr && this->larger_leaf_splits_->leaf_index() >= 0) {
leaf = this->larger_leaf_splits_->LeafIndex(); leaf = this->larger_leaf_splits_->leaf_index();
auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_bests_per_thread); auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_bests_per_thread);
this->best_split_per_leaf_[leaf] = larger_bests_per_thread[larger_best_idx]; this->best_split_per_leaf_[leaf] = larger_bests_per_thread[larger_best_idx];
} }
SplitInfo smaller_best_split, larger_best_split; SplitInfo smaller_best_split, larger_best_split;
smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()]; smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->leaf_index()];
// find local best split for larger leaf // find local best split for larger leaf
if (this->larger_leaf_splits_->LeafIndex() >= 0) { if (this->larger_leaf_splits_->leaf_index() >= 0) {
larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()]; larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->leaf_index()];
} }
// sync global best info // sync global best info
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->config_->max_cat_threshold); SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->config_->max_cat_threshold);
// set best split // set best split
this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split; this->best_split_per_leaf_[this->smaller_leaf_splits_->leaf_index()] = smaller_best_split;
if (this->larger_leaf_splits_->LeafIndex() >= 0) { if (this->larger_leaf_splits_->leaf_index() >= 0) {
this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()] = larger_best_split; this->best_split_per_leaf_[this->larger_leaf_splits_->leaf_index()] = larger_best_split;
} }
} }
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <vector> #include <vector>
#include "split_info.hpp" #include "split_info.hpp"
#include "monotone_constraints.hpp"
namespace LightGBM { namespace LightGBM {
...@@ -58,11 +59,11 @@ class FeatureHistogram { ...@@ -58,11 +59,11 @@ class FeatureHistogram {
meta_ = meta; meta_ = meta;
data_ = data; data_ = data;
if (meta_->bin_type == BinType::NumericalBin) { if (meta_->bin_type == BinType::NumericalBin) {
find_best_threshold_fun_ = std::bind(&FeatureHistogram::FindBestThresholdNumerical, this, std::placeholders::_1, find_best_threshold_fun_ = std::bind(&FeatureHistogram::FindBestThresholdNumerical, this, std::placeholders::_1
std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6); , std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5);
} else { } else {
find_best_threshold_fun_ = std::bind(&FeatureHistogram::FindBestThresholdCategorical, this, std::placeholders::_1, find_best_threshold_fun_ = std::bind(&FeatureHistogram::FindBestThresholdCategorical, this, std::placeholders::_1
std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6); , std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5);
} }
rand_ = Random(meta_->config->extra_seed); rand_ = Random(meta_->config->extra_seed);
} }
...@@ -81,15 +82,15 @@ class FeatureHistogram { ...@@ -81,15 +82,15 @@ class FeatureHistogram {
} }
void FindBestThreshold(double sum_gradient, double sum_hessian, data_size_t num_data, void FindBestThreshold(double sum_gradient, double sum_hessian, data_size_t num_data,
double min_constraint, double max_constraint, SplitInfo* output) { const ConstraintEntry& constraints, SplitInfo* output) {
output->default_left = true; output->default_left = true;
output->gain = kMinScore; output->gain = kMinScore;
find_best_threshold_fun_(sum_gradient, sum_hessian + 2 * kEpsilon, num_data, min_constraint, max_constraint, output); find_best_threshold_fun_(sum_gradient, sum_hessian + 2 * kEpsilon, num_data, constraints, output);
output->gain *= meta_->penalty; output->gain *= meta_->penalty;
} }
void FindBestThresholdNumerical(double sum_gradient, double sum_hessian, data_size_t num_data, void FindBestThresholdNumerical(double sum_gradient, double sum_hessian, data_size_t num_data,
double min_constraint, double max_constraint, SplitInfo* output) { const ConstraintEntry& constraints, SplitInfo* output) {
is_splittable_ = false; is_splittable_ = false;
double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian, double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian,
meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step); meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step);
...@@ -102,26 +103,26 @@ class FeatureHistogram { ...@@ -102,26 +103,26 @@ class FeatureHistogram {
if (meta_->num_bin > 2 && meta_->missing_type != MissingType::None) { if (meta_->num_bin > 2 && meta_->missing_type != MissingType::None) {
if (meta_->missing_type == MissingType::Zero) { if (meta_->missing_type == MissingType::Zero) {
if (is_rand) { if (is_rand) {
FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, true, false, rand_threshold); FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, constraints, min_gain_shift, output, -1, true, false, rand_threshold);
FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, 1, true, false, rand_threshold); FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, constraints, min_gain_shift, output, 1, true, false, rand_threshold);
} else { } else {
FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, true, false, rand_threshold); FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, constraints, min_gain_shift, output, -1, true, false, rand_threshold);
FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, 1, true, false, rand_threshold); FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, constraints, min_gain_shift, output, 1, true, false, rand_threshold);
} }
} else { } else {
if (is_rand) { if (is_rand) {
FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, false, true, rand_threshold); FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, constraints, min_gain_shift, output, -1, false, true, rand_threshold);
FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, 1, false, true, rand_threshold); FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, constraints, min_gain_shift, output, 1, false, true, rand_threshold);
} else { } else {
FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, false, true, rand_threshold); FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, constraints, min_gain_shift, output, -1, false, true, rand_threshold);
FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, 1, false, true, rand_threshold); FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, constraints, min_gain_shift, output, 1, false, true, rand_threshold);
} }
} }
} else { } else {
if (is_rand) { if (is_rand) {
FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, false, false, rand_threshold); FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, constraints, min_gain_shift, output, -1, false, false, rand_threshold);
} else { } else {
FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, false, false, rand_threshold); FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, constraints, min_gain_shift, output, -1, false, false, rand_threshold);
} }
// fix the direction error when only have 2 bins // fix the direction error when only have 2 bins
if (meta_->missing_type == MissingType::NaN) { if (meta_->missing_type == MissingType::NaN) {
...@@ -130,12 +131,10 @@ class FeatureHistogram { ...@@ -130,12 +131,10 @@ class FeatureHistogram {
} }
output->gain -= min_gain_shift; output->gain -= min_gain_shift;
output->monotone_type = meta_->monotone_type; output->monotone_type = meta_->monotone_type;
output->min_constraint = min_constraint;
output->max_constraint = max_constraint;
} }
void FindBestThresholdCategorical(double sum_gradient, double sum_hessian, data_size_t num_data, void FindBestThresholdCategorical(double sum_gradient, double sum_hessian, data_size_t num_data,
double min_constraint, double max_constraint, SplitInfo* output) { const ConstraintEntry& constraints, SplitInfo* output) {
output->default_left = false; output->default_left = false;
double best_gain = kMinScore; double best_gain = kMinScore;
data_size_t best_left_count = 0; data_size_t best_left_count = 0;
...@@ -173,7 +172,7 @@ class FeatureHistogram { ...@@ -173,7 +172,7 @@ class FeatureHistogram {
double sum_other_gradient = sum_gradient - grad; double sum_other_gradient = sum_gradient - grad;
// current split gain // current split gain
double current_gain = GetSplitGains(sum_other_gradient, sum_other_hessian, grad, hess + kEpsilon, double current_gain = GetSplitGains(sum_other_gradient, sum_other_hessian, grad, hess + kEpsilon,
meta_->config->lambda_l1, l2, meta_->config->max_delta_step, min_constraint, max_constraint, 0); meta_->config->lambda_l1, l2, meta_->config->max_delta_step, constraints, 0);
// gain with split is worse than without split // gain with split is worse than without split
if (current_gain <= min_gain_shift) continue; if (current_gain <= min_gain_shift) continue;
...@@ -253,7 +252,7 @@ class FeatureHistogram { ...@@ -253,7 +252,7 @@ class FeatureHistogram {
double sum_right_gradient = sum_gradient - sum_left_gradient; double sum_right_gradient = sum_gradient - sum_left_gradient;
if (!meta_->config->extra_trees || i == rand_threshold) { if (!meta_->config->extra_trees || i == rand_threshold) {
double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian, double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian,
meta_->config->lambda_l1, l2, meta_->config->max_delta_step, min_constraint, max_constraint, 0); meta_->config->lambda_l1, l2, meta_->config->max_delta_step, constraints, 0);
if (current_gain <= min_gain_shift) continue; if (current_gain <= min_gain_shift) continue;
is_splittable_ = true; is_splittable_ = true;
if (current_gain > best_gain) { if (current_gain > best_gain) {
...@@ -271,13 +270,13 @@ class FeatureHistogram { ...@@ -271,13 +270,13 @@ class FeatureHistogram {
if (is_splittable_) { if (is_splittable_) {
output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian, output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian,
meta_->config->lambda_l1, l2, meta_->config->max_delta_step, min_constraint, max_constraint); meta_->config->lambda_l1, l2, meta_->config->max_delta_step, constraints);
output->left_count = best_left_count; output->left_count = best_left_count;
output->left_sum_gradient = best_sum_left_gradient; output->left_sum_gradient = best_sum_left_gradient;
output->left_sum_hessian = best_sum_left_hessian - kEpsilon; output->left_sum_hessian = best_sum_left_hessian - kEpsilon;
output->right_output = CalculateSplittedLeafOutput( output->right_output = CalculateSplittedLeafOutput(
sum_gradient - best_sum_left_gradient, sum_hessian - best_sum_left_hessian, sum_gradient - best_sum_left_gradient, sum_hessian - best_sum_left_hessian,
meta_->config->lambda_l1, l2, meta_->config->max_delta_step, min_constraint, max_constraint); meta_->config->lambda_l1, l2, meta_->config->max_delta_step, constraints);
output->right_count = num_data - best_left_count; output->right_count = num_data - best_left_count;
output->right_sum_gradient = sum_gradient - best_sum_left_gradient; output->right_sum_gradient = sum_gradient - best_sum_left_gradient;
output->right_sum_hessian = sum_hessian - best_sum_left_hessian - kEpsilon; output->right_sum_hessian = sum_hessian - best_sum_left_hessian - kEpsilon;
...@@ -301,8 +300,6 @@ class FeatureHistogram { ...@@ -301,8 +300,6 @@ class FeatureHistogram {
} }
} }
output->monotone_type = 0; output->monotone_type = 0;
output->min_constraint = min_constraint;
output->max_constraint = max_constraint;
} }
} }
...@@ -481,9 +478,9 @@ class FeatureHistogram { ...@@ -481,9 +478,9 @@ class FeatureHistogram {
static double GetSplitGains(double sum_left_gradients, double sum_left_hessians, static double GetSplitGains(double sum_left_gradients, double sum_left_hessians,
double sum_right_gradients, double sum_right_hessians, double sum_right_gradients, double sum_right_hessians,
double l1, double l2, double max_delta_step, double l1, double l2, double max_delta_step,
double min_constraint, double max_constraint, int8_t monotone_constraint) { const ConstraintEntry& constraints, int8_t monotone_constraint) {
double left_output = CalculateSplittedLeafOutput(sum_left_gradients, sum_left_hessians, l1, l2, max_delta_step, min_constraint, max_constraint); double left_output = CalculateSplittedLeafOutput(sum_left_gradients, sum_left_hessians, l1, l2, max_delta_step, constraints);
double right_output = CalculateSplittedLeafOutput(sum_right_gradients, sum_right_hessians, l1, l2, max_delta_step, min_constraint, max_constraint); double right_output = CalculateSplittedLeafOutput(sum_right_gradients, sum_right_hessians, l1, l2, max_delta_step, constraints);
if (((monotone_constraint > 0) && (left_output > right_output)) || if (((monotone_constraint > 0) && (left_output > right_output)) ||
((monotone_constraint < 0) && (left_output < right_output))) { ((monotone_constraint < 0) && (left_output < right_output))) {
return 0; return 0;
...@@ -498,13 +495,14 @@ class FeatureHistogram { ...@@ -498,13 +495,14 @@ class FeatureHistogram {
* \param sum_hessians * \param sum_hessians
* \return leaf output * \return leaf output
*/ */
static double CalculateSplittedLeafOutput(double sum_gradients, double sum_hessians, double l1, double l2, double max_delta_step, static double
double min_constraint, double max_constraint) { CalculateSplittedLeafOutput(double sum_gradients, double sum_hessians,
double l1, double l2, double max_delta_step, const ConstraintEntry& constraints) {
double ret = CalculateSplittedLeafOutput(sum_gradients, sum_hessians, l1, l2, max_delta_step); double ret = CalculateSplittedLeafOutput(sum_gradients, sum_hessians, l1, l2, max_delta_step);
if (ret < min_constraint) { if (ret < constraints.min) {
ret = min_constraint; ret = constraints.min;
} else if (ret > max_constraint) { } else if (ret > constraints.max) {
ret = max_constraint; ret = constraints.max;
} }
return ret; return ret;
} }
...@@ -526,7 +524,7 @@ class FeatureHistogram { ...@@ -526,7 +524,7 @@ class FeatureHistogram {
} }
template<bool is_rand> template<bool is_rand>
void FindBestThresholdSequence(double sum_gradient, double sum_hessian, data_size_t num_data, double min_constraint, double max_constraint, void FindBestThresholdSequence(double sum_gradient, double sum_hessian, data_size_t num_data, const ConstraintEntry& constraints,
double min_gain_shift, SplitInfo* output, int dir, bool skip_default_bin, bool use_na_as_missing, int rand_threshold) { double min_gain_shift, SplitInfo* output, int dir, bool skip_default_bin, bool use_na_as_missing, int rand_threshold) {
const int8_t offset = meta_->offset; const int8_t offset = meta_->offset;
...@@ -571,7 +569,7 @@ class FeatureHistogram { ...@@ -571,7 +569,7 @@ class FeatureHistogram {
// current split gain // current split gain
double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian, double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian,
meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step, meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step,
min_constraint, max_constraint, meta_->monotone_type); constraints, meta_->monotone_type);
// gain with split is worse than without split // gain with split is worse than without split
if (current_gain <= min_gain_shift) continue; if (current_gain <= min_gain_shift) continue;
...@@ -635,7 +633,7 @@ class FeatureHistogram { ...@@ -635,7 +633,7 @@ class FeatureHistogram {
// current split gain // current split gain
double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian, double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian,
meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step, meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step,
min_constraint, max_constraint, meta_->monotone_type); constraints, meta_->monotone_type);
// gain with split is worse than without split // gain with split is worse than without split
if (current_gain <= min_gain_shift) continue; if (current_gain <= min_gain_shift) continue;
...@@ -656,16 +654,17 @@ class FeatureHistogram { ...@@ -656,16 +654,17 @@ class FeatureHistogram {
if (is_splittable_ && best_gain > output->gain) { if (is_splittable_ && best_gain > output->gain) {
// update split information // update split information
output->threshold = best_threshold; output->threshold = best_threshold;
output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian, output->left_output = CalculateSplittedLeafOutput(
meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step, best_sum_left_gradient, best_sum_left_hessian,
min_constraint, max_constraint); meta_->config->lambda_l1, meta_->config->lambda_l2,
meta_->config->max_delta_step, constraints);
output->left_count = best_left_count; output->left_count = best_left_count;
output->left_sum_gradient = best_sum_left_gradient; output->left_sum_gradient = best_sum_left_gradient;
output->left_sum_hessian = best_sum_left_hessian - kEpsilon; output->left_sum_hessian = best_sum_left_hessian - kEpsilon;
output->right_output = CalculateSplittedLeafOutput( output->right_output = CalculateSplittedLeafOutput(
sum_gradient - best_sum_left_gradient, sum_hessian - best_sum_left_hessian, sum_gradient - best_sum_left_gradient, sum_hessian - best_sum_left_hessian,
meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step, meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step,
min_constraint, max_constraint); constraints);
output->right_count = num_data - best_left_count; output->right_count = num_data - best_left_count;
output->right_sum_gradient = sum_gradient - best_sum_left_gradient; output->right_sum_gradient = sum_gradient - best_sum_left_gradient;
output->right_sum_hessian = sum_hessian - best_sum_left_hessian - kEpsilon; output->right_sum_hessian = sum_hessian - best_sum_left_hessian - kEpsilon;
...@@ -681,7 +680,9 @@ class FeatureHistogram { ...@@ -681,7 +680,9 @@ class FeatureHistogram {
/*! \brief random number generator for extremely randomized trees */ /*! \brief random number generator for extremely randomized trees */
Random rand_; Random rand_;
std::function<void(double, double, data_size_t, double, double, SplitInfo*)> find_best_threshold_fun_; std::function<void(double, double, data_size_t, const ConstraintEntry&,
SplitInfo*)>
find_best_threshold_fun_;
}; };
class HistogramPool { class HistogramPool {
public: public:
......
...@@ -56,17 +56,17 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(con ...@@ -56,17 +56,17 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(con
TREELEARNER_T::FindBestSplitsFromHistograms(is_feature_used, use_subtract); TREELEARNER_T::FindBestSplitsFromHistograms(is_feature_used, use_subtract);
SplitInfo smaller_best_split, larger_best_split; SplitInfo smaller_best_split, larger_best_split;
// get best split at smaller leaf // get best split at smaller leaf
smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()]; smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->leaf_index()];
// find local best split for larger leaf // find local best split for larger leaf
if (this->larger_leaf_splits_->LeafIndex() >= 0) { if (this->larger_leaf_splits_->leaf_index() >= 0) {
larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()]; larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->leaf_index()];
} }
// sync global best info // sync global best info
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->config_->max_cat_threshold); SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->config_->max_cat_threshold);
// update best split // update best split
this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split; this->best_split_per_leaf_[this->smaller_leaf_splits_->leaf_index()] = smaller_best_split;
if (this->larger_leaf_splits_->LeafIndex() >= 0) { if (this->larger_leaf_splits_->leaf_index() >= 0) {
this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()] = larger_best_split; this->best_split_per_leaf_[this->larger_leaf_splits_->leaf_index()] = larger_best_split;
} }
} }
......
...@@ -1089,14 +1089,8 @@ void GPUTreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right ...@@ -1089,14 +1089,8 @@ void GPUTreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right
Log::Fatal("Bug in GPU histogram! split %d: %d, smaller_leaf: %d, larger_leaf: %d\n", best_split_info.left_count, best_split_info.right_count, smaller_leaf_splits_->num_data_in_leaf(), larger_leaf_splits_->num_data_in_leaf()); Log::Fatal("Bug in GPU histogram! split %d: %d, smaller_leaf: %d, larger_leaf: %d\n", best_split_info.left_count, best_split_info.right_count, smaller_leaf_splits_->num_data_in_leaf(), larger_leaf_splits_->num_data_in_leaf());
} }
} else { } else {
double smaller_min = smaller_leaf_splits_->min_constraint();
double smaller_max = smaller_leaf_splits_->max_constraint();
double larger_min = larger_leaf_splits_->min_constraint();
double larger_max = larger_leaf_splits_->max_constraint();
smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(), best_split_info.right_sum_gradient, best_split_info.right_sum_hessian); smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(), best_split_info.right_sum_gradient, best_split_info.right_sum_hessian);
larger_leaf_splits_->Init(*left_leaf, data_partition_.get(), best_split_info.left_sum_gradient, best_split_info.left_sum_hessian); larger_leaf_splits_->Init(*left_leaf, data_partition_.get(), best_split_info.left_sum_gradient, best_split_info.left_sum_hessian);
smaller_leaf_splits_->SetValueConstraint(smaller_min, smaller_max);
larger_leaf_splits_->SetValueConstraint(larger_min, larger_max);
if ((best_split_info.left_count != larger_leaf_splits_->num_data_in_leaf()) || if ((best_split_info.left_count != larger_leaf_splits_->num_data_in_leaf()) ||
(best_split_info.right_count!= smaller_leaf_splits_->num_data_in_leaf())) { (best_split_info.right_count!= smaller_leaf_splits_->num_data_in_leaf())) {
Log::Fatal("Bug in GPU histogram! split %d: %d, smaller_leaf: %d, larger_leaf: %d\n", best_split_info.left_count, best_split_info.right_count, smaller_leaf_splits_->num_data_in_leaf(), larger_leaf_splits_->num_data_in_leaf()); Log::Fatal("Bug in GPU histogram! split %d: %d, smaller_leaf: %d, larger_leaf: %d\n", best_split_info.left_count, best_split_info.right_count, smaller_leaf_splits_->num_data_in_leaf(), larger_leaf_splits_->num_data_in_leaf());
......
...@@ -43,16 +43,8 @@ class LeafSplits { ...@@ -43,16 +43,8 @@ class LeafSplits {
data_indices_ = data_partition->GetIndexOnLeaf(leaf, &num_data_in_leaf_); data_indices_ = data_partition->GetIndexOnLeaf(leaf, &num_data_in_leaf_);
sum_gradients_ = sum_gradients; sum_gradients_ = sum_gradients;
sum_hessians_ = sum_hessians; sum_hessians_ = sum_hessians;
min_val_ = -std::numeric_limits<double>::max();
max_val_ = std::numeric_limits<double>::max();
} }
void SetValueConstraint(double min, double max) {
min_val_ = min;
max_val_ = max;
}
/*! /*!
* \brief Init splits on current leaf, it will traverse all data to sum up the results * \brief Init splits on current leaf, it will traverse all data to sum up the results
* \param gradients * \param gradients
...@@ -71,8 +63,6 @@ class LeafSplits { ...@@ -71,8 +63,6 @@ class LeafSplits {
} }
sum_gradients_ = tmp_sum_gradients; sum_gradients_ = tmp_sum_gradients;
sum_hessians_ = tmp_sum_hessians; sum_hessians_ = tmp_sum_hessians;
min_val_ = -std::numeric_limits<double>::max();
max_val_ = std::numeric_limits<double>::max();
} }
/*! /*!
...@@ -95,8 +85,6 @@ class LeafSplits { ...@@ -95,8 +85,6 @@ class LeafSplits {
} }
sum_gradients_ = tmp_sum_gradients; sum_gradients_ = tmp_sum_gradients;
sum_hessians_ = tmp_sum_hessians; sum_hessians_ = tmp_sum_hessians;
min_val_ = -std::numeric_limits<double>::max();
max_val_ = std::numeric_limits<double>::max();
} }
...@@ -109,8 +97,6 @@ class LeafSplits { ...@@ -109,8 +97,6 @@ class LeafSplits {
leaf_index_ = 0; leaf_index_ = 0;
sum_gradients_ = sum_gradients; sum_gradients_ = sum_gradients;
sum_hessians_ = sum_hessians; sum_hessians_ = sum_hessians;
min_val_ = -std::numeric_limits<double>::max();
max_val_ = std::numeric_limits<double>::max();
} }
/*! /*!
...@@ -120,13 +106,11 @@ class LeafSplits { ...@@ -120,13 +106,11 @@ class LeafSplits {
leaf_index_ = -1; leaf_index_ = -1;
data_indices_ = nullptr; data_indices_ = nullptr;
num_data_in_leaf_ = 0; num_data_in_leaf_ = 0;
min_val_ = -std::numeric_limits<double>::max();
max_val_ = std::numeric_limits<double>::max();
} }
/*! \brief Get current leaf index */ /*! \brief Get current leaf index */
int LeafIndex() const { return leaf_index_; } int leaf_index() const { return leaf_index_; }
/*! \brief Get numer of data in current leaf */ /*! \brief Get numer of data in current leaf */
data_size_t num_data_in_leaf() const { return num_data_in_leaf_; } data_size_t num_data_in_leaf() const { return num_data_in_leaf_; }
...@@ -137,10 +121,6 @@ class LeafSplits { ...@@ -137,10 +121,6 @@ class LeafSplits {
/*! \brief Get sum of hessians of current leaf */ /*! \brief Get sum of hessians of current leaf */
double sum_hessians() const { return sum_hessians_; } double sum_hessians() const { return sum_hessians_; }
double max_constraint() const { return max_val_; }
double min_constraint() const { return min_val_; }
/*! \brief Get indices of data of current leaf */ /*! \brief Get indices of data of current leaf */
const data_size_t* data_indices() const { return data_indices_; } const data_size_t* data_indices() const { return data_indices_; }
...@@ -158,8 +138,6 @@ class LeafSplits { ...@@ -158,8 +138,6 @@ class LeafSplits {
double sum_hessians_; double sum_hessians_;
/*! \brief indices of data of current leaf */ /*! \brief indices of data of current leaf */
const data_size_t* data_indices_; const data_size_t* data_indices_;
double min_val_;
double max_val_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
#ifndef LIGHTGBM_TREELEARNER_MONOTONE_CONSTRAINTS_H_
#define LIGHTGBM_TREELEARNER_MONOTONE_CONSTRAINTS_H_
#include <algorithm>
#include <vector>
#include <cstdint>
#include <limits>
namespace LightGBM {
/*! \brief Holds the [min, max] bound on the output of a single leaf,
 *         as imposed by monotone constraints on ancestor splits. */
struct ConstraintEntry {
  /*! \brief Lower bound on the leaf output.
   *  For double, lowest() == -max(), i.e. effectively unconstrained. */
  double min = std::numeric_limits<double>::lowest();
  /*! \brief Upper bound on the leaf output (unconstrained by default). */
  double max = std::numeric_limits<double>::max();

  ConstraintEntry() = default;

  /*! \brief Restores the unconstrained state. */
  void Reset() {
    min = std::numeric_limits<double>::lowest();
    max = std::numeric_limits<double>::max();
  }

  /*! \brief Tightens the lower bound; an existing tighter bound is kept. */
  void UpdateMin(double new_min) { min = std::max(new_min, min); }

  /*! \brief Tightens the upper bound; an existing tighter bound is kept. */
  void UpdateMax(double new_max) { max = std::min(new_max, max); }
};
/*!
 * \brief Stores and propagates per-leaf monotone-constraint entries while a
 *        tree is being grown.
 * \tparam ConstraintEntry entry type providing Reset(), UpdateMin(double)
 *         and UpdateMax(double) (e.g. the ConstraintEntry struct above).
 */
template <typename ConstraintEntry>
class LeafConstraints {
 public:
  /*! \brief Allocates one (default, i.e. unconstrained) entry per possible leaf. */
  explicit LeafConstraints(int num_leaves) : num_leaves_(num_leaves) {
    entries_.resize(num_leaves_);
  }

  /*! \brief Resets every leaf back to the unconstrained state. */
  void Reset() {
    for (auto& entry : entries_) {
      entry.Reset();
    }
  }

  /*!
   * \brief Propagates constraints after leaf `leaf` is split, with the right
   *        child placed in `new_leaf`. The new leaf first inherits the parent
   *        constraints; for a numerical split with a monotone feature, the two
   *        children are additionally separated at the midpoint of their outputs.
   * \param is_numerical_split true for numerical splits (categorical splits
   *        only inherit, they impose no new bound)
   * \param leaf index of the split leaf (kept by the left child)
   * \param new_leaf index of the newly created (right) leaf
   * \param monotone_type <0 decreasing, >0 increasing, 0 unconstrained
   * \param right_output output value of the right child
   * \param left_output output value of the left child
   */
  void UpdateConstraints(bool is_numerical_split, int leaf, int new_leaf,
                         int8_t monotone_type, double right_output,
                         double left_output) {
    entries_[new_leaf] = entries_[leaf];
    if (is_numerical_split) {
      double mid = (left_output + right_output) / 2.0f;
      if (monotone_type < 0) {
        entries_[leaf].UpdateMin(mid);
        entries_[new_leaf].UpdateMax(mid);
      } else if (monotone_type > 0) {
        entries_[leaf].UpdateMax(mid);
        entries_[new_leaf].UpdateMin(mid);
      }
    }
  }

  /*! \brief Read-only access to the constraint entry of one leaf. */
  const ConstraintEntry& Get(int leaf_idx) const { return entries_[leaf_idx]; }

 private:
  /*! \brief Number of entries allocated (maximum number of leaves). */
  int num_leaves_;
  /*! \brief One constraint entry per leaf index. */
  std::vector<ConstraintEntry> entries_;
};
} // namespace LightGBM
#endif  // LIGHTGBM_TREELEARNER_MONOTONE_CONSTRAINTS_H_
...@@ -53,6 +53,7 @@ void SerialTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian ...@@ -53,6 +53,7 @@ void SerialTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian
// push split information for all leaves // push split information for all leaves
best_split_per_leaf_.resize(config_->num_leaves); best_split_per_leaf_.resize(config_->num_leaves);
constraints_.reset(new LeafConstraints<ConstraintEntry>(config_->num_leaves));
// initialize splits for leaf // initialize splits for leaf
smaller_leaf_splits_.reset(new LeafSplits(train_data_->num_data())); smaller_leaf_splits_.reset(new LeafSplits(train_data_->num_data()));
...@@ -291,6 +292,8 @@ void SerialTreeLearner::BeforeTrain() { ...@@ -291,6 +292,8 @@ void SerialTreeLearner::BeforeTrain() {
// initialize data partition // initialize data partition
data_partition_->Init(); data_partition_->Init();
constraints_->Reset();
// reset the splits for leaves // reset the splits for leaves
for (int i = 0; i < config_->num_leaves; ++i) { for (int i = 0; i < config_->num_leaves; ++i) {
best_split_per_leaf_[i].Reset(); best_split_per_leaf_[i].Reset();
...@@ -408,27 +411,19 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>& ...@@ -408,27 +411,19 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>&
OMP_LOOP_EX_BEGIN(); OMP_LOOP_EX_BEGIN();
if (!is_feature_used[feature_index]) { continue; } if (!is_feature_used[feature_index]) { continue; }
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
SplitInfo smaller_split;
train_data_->FixHistogram(feature_index, train_data_->FixHistogram(feature_index,
smaller_leaf_splits_->sum_gradients(), smaller_leaf_splits_->sum_hessians(), smaller_leaf_splits_->sum_gradients(), smaller_leaf_splits_->sum_hessians(),
smaller_leaf_histogram_array_[feature_index].RawData()); smaller_leaf_histogram_array_[feature_index].RawData());
int real_fidx = train_data_->RealFeatureIndex(feature_index); int real_fidx = train_data_->RealFeatureIndex(feature_index);
smaller_leaf_histogram_array_[feature_index].FindBestThreshold(
smaller_leaf_splits_->sum_gradients(), ComputeBestSplitForFeature(smaller_leaf_histogram_array_, feature_index,
smaller_leaf_splits_->sum_hessians(), real_fidx,
smaller_node_used_features[feature_index],
smaller_leaf_splits_->num_data_in_leaf(), smaller_leaf_splits_->num_data_in_leaf(),
smaller_leaf_splits_->min_constraint(), smaller_leaf_splits_.get(), &smaller_best[tid]);
smaller_leaf_splits_->max_constraint(),
&smaller_split);
smaller_split.feature = real_fidx;
if (cegb_ != nullptr) {
smaller_split.gain -= cegb_->DetlaGain(feature_index, real_fidx, smaller_leaf_splits_->LeafIndex(), smaller_leaf_splits_->num_data_in_leaf(), smaller_split);
}
if (smaller_split > smaller_best[tid] && smaller_node_used_features[feature_index]) {
smaller_best[tid] = smaller_split;
}
// only has root leaf // only has root leaf
if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->LeafIndex() < 0) { continue; } if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->leaf_index() < 0) { continue; }
if (use_subtract) { if (use_subtract) {
larger_leaf_histogram_array_[feature_index].Subtract(smaller_leaf_histogram_array_[feature_index]); larger_leaf_histogram_array_[feature_index].Subtract(smaller_leaf_histogram_array_[feature_index]);
...@@ -436,31 +431,23 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>& ...@@ -436,31 +431,23 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>&
train_data_->FixHistogram(feature_index, larger_leaf_splits_->sum_gradients(), larger_leaf_splits_->sum_hessians(), train_data_->FixHistogram(feature_index, larger_leaf_splits_->sum_gradients(), larger_leaf_splits_->sum_hessians(),
larger_leaf_histogram_array_[feature_index].RawData()); larger_leaf_histogram_array_[feature_index].RawData());
} }
SplitInfo larger_split;
// find best threshold for larger child ComputeBestSplitForFeature(larger_leaf_histogram_array_, feature_index,
larger_leaf_histogram_array_[feature_index].FindBestThreshold( real_fidx,
larger_leaf_splits_->sum_gradients(), larger_node_used_features[feature_index],
larger_leaf_splits_->sum_hessians(),
larger_leaf_splits_->num_data_in_leaf(), larger_leaf_splits_->num_data_in_leaf(),
larger_leaf_splits_->min_constraint(), larger_leaf_splits_.get(),
larger_leaf_splits_->max_constraint(), &larger_best[tid]);
&larger_split);
larger_split.feature = real_fidx;
if (cegb_ != nullptr) {
larger_split.gain -= cegb_->DetlaGain(feature_index, real_fidx, larger_leaf_splits_->LeafIndex(), larger_leaf_splits_->num_data_in_leaf(), larger_split);
}
if (larger_split > larger_best[tid] && larger_node_used_features[feature_index]) {
larger_best[tid] = larger_split;
}
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
OMP_THROW_EX(); OMP_THROW_EX();
auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_best); auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_best);
int leaf = smaller_leaf_splits_->LeafIndex(); int leaf = smaller_leaf_splits_->leaf_index();
best_split_per_leaf_[leaf] = smaller_best[smaller_best_idx]; best_split_per_leaf_[leaf] = smaller_best[smaller_best_idx];
if (larger_leaf_splits_ != nullptr && larger_leaf_splits_->LeafIndex() >= 0) { if (larger_leaf_splits_ != nullptr && larger_leaf_splits_->leaf_index() >= 0) {
leaf = larger_leaf_splits_->LeafIndex(); leaf = larger_leaf_splits_->leaf_index();
auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_best); auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_best);
best_split_per_leaf_[leaf] = larger_best[larger_best_idx]; best_split_per_leaf_[leaf] = larger_best[larger_best_idx];
} }
...@@ -692,10 +679,11 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri ...@@ -692,10 +679,11 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
static_cast<float>(best_split_info.gain), static_cast<float>(best_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type()); train_data_->FeatureBinMapper(inner_feature_index)->missing_type());
} }
#ifdef DEBUG
CHECK(*right_leaf == next_leaf_id); CHECK(*right_leaf == next_leaf_id);
#endif
auto p_left = smaller_leaf_splits_.get();
auto p_right = larger_leaf_splits_.get();
// init the leaves that used on next iteration // init the leaves that used on next iteration
if (best_split_info.left_count < best_split_info.right_count) { if (best_split_info.left_count < best_split_info.right_count) {
CHECK(best_split_info.left_count > 0); CHECK(best_split_info.left_count > 0);
...@@ -705,21 +693,11 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri ...@@ -705,21 +693,11 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
CHECK(best_split_info.right_count > 0); CHECK(best_split_info.right_count > 0);
smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(), best_split_info.right_sum_gradient, best_split_info.right_sum_hessian); smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(), best_split_info.right_sum_gradient, best_split_info.right_sum_hessian);
larger_leaf_splits_->Init(*left_leaf, data_partition_.get(), best_split_info.left_sum_gradient, best_split_info.left_sum_hessian); larger_leaf_splits_->Init(*left_leaf, data_partition_.get(), best_split_info.left_sum_gradient, best_split_info.left_sum_hessian);
p_right = smaller_leaf_splits_.get();
p_left = larger_leaf_splits_.get();
}
p_left->SetValueConstraint(best_split_info.min_constraint, best_split_info.max_constraint);
p_right->SetValueConstraint(best_split_info.min_constraint, best_split_info.max_constraint);
if (is_numerical_split) {
double mid = (best_split_info.left_output + best_split_info.right_output) / 2.0f;
if (best_split_info.monotone_type < 0) {
p_left->SetValueConstraint(mid, best_split_info.max_constraint);
p_right->SetValueConstraint(best_split_info.min_constraint, mid);
} else if (best_split_info.monotone_type > 0) {
p_left->SetValueConstraint(best_split_info.min_constraint, mid);
p_right->SetValueConstraint(mid, best_split_info.max_constraint);
}
} }
constraints_->UpdateConstraints(
is_numerical_split, *left_leaf, *right_leaf,
best_split_info.monotone_type, best_split_info.right_output,
best_split_info.left_output);
} }
...@@ -763,4 +741,26 @@ void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj ...@@ -763,4 +741,26 @@ void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj
} }
} }
/*!
 * \brief Evaluates the best split of one feature on one leaf and, when the
 *        candidate beats the current best, stores it into *best_split.
 * \param histogram_array_ histogram array for the leaf being evaluated
 * \param feature_index inner (dataset) feature index
 * \param real_fidx real feature index recorded into the resulting SplitInfo
 * \param is_feature_used when false the feature is skipped entirely
 * \param num_data number of data points in the leaf
 * \param leaf_splits split statistics (gradient/hessian sums) of the leaf
 * \param best_split in/out: running best split, updated only on improvement
 */
void SerialTreeLearner::ComputeBestSplitForFeature(
    FeatureHistogram* histogram_array_, int feature_index, int real_fidx,
    bool is_feature_used, int num_data, const LeafSplits* leaf_splits,
    SplitInfo* best_split) {
  if (!is_feature_used) return;
  const int leaf_index = leaf_splits->leaf_index();
  SplitInfo candidate;
  // Threshold search honors the monotone constraints recorded for this leaf.
  histogram_array_[feature_index].FindBestThreshold(
      leaf_splits->sum_gradients(), leaf_splits->sum_hessians(), num_data,
      constraints_->Get(leaf_index), &candidate);
  candidate.feature = real_fidx;
  // If configured, cegb_ subtracts a feature-usage penalty from the gain.
  if (cegb_ != nullptr) {
    candidate.gain -= cegb_->DetlaGain(feature_index, real_fidx, leaf_index,
                                       num_data, candidate);
  }
  if (candidate > *best_split) {
    *best_split = candidate;
  }
}
} // namespace LightGBM } // namespace LightGBM
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "feature_histogram.hpp" #include "feature_histogram.hpp"
#include "leaf_splits.hpp" #include "leaf_splits.hpp"
#include "split_info.hpp" #include "split_info.hpp"
#include "monotone_constraints.hpp"
#ifdef USE_GPU #ifdef USE_GPU
// Use 4KBytes aligned allocator for ordered gradients and ordered hessians when GPU is enabled. // Use 4KBytes aligned allocator for ordered gradients and ordered hessians when GPU is enabled.
...@@ -82,6 +83,12 @@ class SerialTreeLearner: public TreeLearner { ...@@ -82,6 +83,12 @@ class SerialTreeLearner: public TreeLearner {
bool IsHistColWise() const override { return is_hist_colwise_; } bool IsHistColWise() const override { return is_hist_colwise_; }
protected: protected:
void ComputeBestSplitForFeature(FeatureHistogram* histogram_array_,
int feature_index, int real_fidx,
bool is_feature_used, int num_data,
const LeafSplits* leaf_splits,
SplitInfo* best_split);
void GetMultiValBin(const Dataset* dataset, bool is_first_time); void GetMultiValBin(const Dataset* dataset, bool is_first_time);
virtual std::vector<int8_t> GetUsedFeatures(bool is_tree_level); virtual std::vector<int8_t> GetUsedFeatures(bool is_tree_level);
...@@ -151,6 +158,8 @@ class SerialTreeLearner: public TreeLearner { ...@@ -151,6 +158,8 @@ class SerialTreeLearner: public TreeLearner {
std::vector<SplitInfo> best_split_per_leaf_; std::vector<SplitInfo> best_split_per_leaf_;
/*! \brief store best split per feature for all leaves */ /*! \brief store best split per feature for all leaves */
std::vector<SplitInfo> splits_per_leaf_; std::vector<SplitInfo> splits_per_leaf_;
// Stores minimum and maximum constraints for each leaf
std::unique_ptr<LeafConstraints<ConstraintEntry>> constraints_;
/*! \brief stores best thresholds for all feature for smaller leaf */ /*! \brief stores best thresholds for all feature for smaller leaf */
std::unique_ptr<LeafSplits> smaller_leaf_splits_; std::unique_ptr<LeafSplits> smaller_leaf_splits_;
......
...@@ -48,8 +48,6 @@ struct SplitInfo { ...@@ -48,8 +48,6 @@ struct SplitInfo {
/*! \brief True if default split is left */ /*! \brief True if default split is left */
bool default_left = true; bool default_left = true;
int8_t monotone_type = 0; int8_t monotone_type = 0;
double min_constraint = -std::numeric_limits<double>::max();
double max_constraint = std::numeric_limits<double>::max();
inline static int Size(int max_cat_threshold) { inline static int Size(int max_cat_threshold) {
return 2 * sizeof(int) + sizeof(uint32_t) + sizeof(bool) + sizeof(double) * 9 + sizeof(data_size_t) * 2 + max_cat_threshold * sizeof(uint32_t) + sizeof(int8_t); return 2 * sizeof(int) + sizeof(uint32_t) + sizeof(bool) + sizeof(double) * 9 + sizeof(data_size_t) * 2 + max_cat_threshold * sizeof(uint32_t) + sizeof(int8_t);
} }
...@@ -81,10 +79,6 @@ struct SplitInfo { ...@@ -81,10 +79,6 @@ struct SplitInfo {
buffer += sizeof(default_left); buffer += sizeof(default_left);
std::memcpy(buffer, &monotone_type, sizeof(monotone_type)); std::memcpy(buffer, &monotone_type, sizeof(monotone_type));
buffer += sizeof(monotone_type); buffer += sizeof(monotone_type);
std::memcpy(buffer, &min_constraint, sizeof(min_constraint));
buffer += sizeof(min_constraint);
std::memcpy(buffer, &max_constraint, sizeof(max_constraint));
buffer += sizeof(max_constraint);
std::memcpy(buffer, &num_cat_threshold, sizeof(num_cat_threshold)); std::memcpy(buffer, &num_cat_threshold, sizeof(num_cat_threshold));
buffer += sizeof(num_cat_threshold); buffer += sizeof(num_cat_threshold);
std::memcpy(buffer, cat_threshold.data(), sizeof(uint32_t) * num_cat_threshold); std::memcpy(buffer, cat_threshold.data(), sizeof(uint32_t) * num_cat_threshold);
...@@ -117,10 +111,6 @@ struct SplitInfo { ...@@ -117,10 +111,6 @@ struct SplitInfo {
buffer += sizeof(default_left); buffer += sizeof(default_left);
std::memcpy(&monotone_type, buffer, sizeof(monotone_type)); std::memcpy(&monotone_type, buffer, sizeof(monotone_type));
buffer += sizeof(monotone_type); buffer += sizeof(monotone_type);
std::memcpy(&min_constraint, buffer, sizeof(min_constraint));
buffer += sizeof(min_constraint);
std::memcpy(&max_constraint, buffer, sizeof(max_constraint));
buffer += sizeof(max_constraint);
std::memcpy(&num_cat_threshold, buffer, sizeof(num_cat_threshold)); std::memcpy(&num_cat_threshold, buffer, sizeof(num_cat_threshold));
buffer += sizeof(num_cat_threshold); buffer += sizeof(num_cat_threshold);
cat_threshold.resize(num_cat_threshold); cat_threshold.resize(num_cat_threshold);
......
...@@ -292,16 +292,13 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() { ...@@ -292,16 +292,13 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
this->smaller_leaf_splits_->sum_gradients(), this->smaller_leaf_splits_->sum_hessians(), this->smaller_leaf_splits_->sum_gradients(), this->smaller_leaf_splits_->sum_hessians(),
this->smaller_leaf_histogram_array_[feature_index].RawData()); this->smaller_leaf_histogram_array_[feature_index].RawData());
this->smaller_leaf_histogram_array_[feature_index].FindBestThreshold( this->ComputeBestSplitForFeature(
this->smaller_leaf_splits_->sum_gradients(), this->smaller_leaf_histogram_array_, feature_index, real_feature_index,
this->smaller_leaf_splits_->sum_hessians(), true, this->smaller_leaf_splits_->num_data_in_leaf(),
this->smaller_leaf_splits_->num_data_in_leaf(), this->smaller_leaf_splits_.get(),
this->smaller_leaf_splits_->min_constraint(),
this->smaller_leaf_splits_->max_constraint(),
&smaller_bestsplit_per_features[feature_index]); &smaller_bestsplit_per_features[feature_index]);
smaller_bestsplit_per_features[feature_index].feature = real_feature_index;
// only has root leaf // only has root leaf
if (this->larger_leaf_splits_ == nullptr || this->larger_leaf_splits_->LeafIndex() < 0) { continue; } if (this->larger_leaf_splits_ == nullptr || this->larger_leaf_splits_->leaf_index() < 0) { continue; }
if (use_subtract) { if (use_subtract) {
this->larger_leaf_histogram_array_[feature_index].Subtract(this->smaller_leaf_histogram_array_[feature_index]); this->larger_leaf_histogram_array_[feature_index].Subtract(this->smaller_leaf_histogram_array_[feature_index]);
...@@ -309,15 +306,11 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() { ...@@ -309,15 +306,11 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
this->train_data_->FixHistogram(feature_index, this->larger_leaf_splits_->sum_gradients(), this->larger_leaf_splits_->sum_hessians(), this->train_data_->FixHistogram(feature_index, this->larger_leaf_splits_->sum_gradients(), this->larger_leaf_splits_->sum_hessians(),
this->larger_leaf_histogram_array_[feature_index].RawData()); this->larger_leaf_histogram_array_[feature_index].RawData());
} }
// find best threshold for larger child this->ComputeBestSplitForFeature(
this->larger_leaf_histogram_array_[feature_index].FindBestThreshold( this->larger_leaf_histogram_array_, feature_index, real_feature_index,
this->larger_leaf_splits_->sum_gradients(), true, this->larger_leaf_splits_->num_data_in_leaf(),
this->larger_leaf_splits_->sum_hessians(), this->larger_leaf_splits_.get(),
this->larger_leaf_splits_->num_data_in_leaf(),
this->larger_leaf_splits_->min_constraint(),
this->larger_leaf_splits_->max_constraint(),
&larger_bestsplit_per_features[feature_index]); &larger_bestsplit_per_features[feature_index]);
larger_bestsplit_per_features[feature_index].feature = real_feature_index;
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
OMP_THROW_EX(); OMP_THROW_EX();
...@@ -359,8 +352,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() { ...@@ -359,8 +352,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
} }
// global voting // global voting
std::vector<int> smaller_top_features, larger_top_features; std::vector<int> smaller_top_features, larger_top_features;
GlobalVoting(this->smaller_leaf_splits_->LeafIndex(), smaller_top_k_splits_global, &smaller_top_features); GlobalVoting(this->smaller_leaf_splits_->leaf_index(), smaller_top_k_splits_global, &smaller_top_features);
GlobalVoting(this->larger_leaf_splits_->LeafIndex(), larger_top_k_splits_global, &larger_top_features); GlobalVoting(this->larger_leaf_splits_->leaf_index(), larger_top_k_splits_global, &larger_top_features);
// copy local histgrams to buffer // copy local histgrams to buffer
CopyLocalHistogram(smaller_top_features, larger_top_features); CopyLocalHistogram(smaller_top_features, larger_top_features);
...@@ -390,7 +383,6 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons ...@@ -390,7 +383,6 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
const int real_feature_index = this->train_data_->RealFeatureIndex(feature_index); const int real_feature_index = this->train_data_->RealFeatureIndex(feature_index);
if (smaller_is_feature_aggregated_[feature_index]) { if (smaller_is_feature_aggregated_[feature_index]) {
SplitInfo smaller_split;
// restore from buffer // restore from buffer
smaller_leaf_histogram_array_global_[feature_index].FromMemory( smaller_leaf_histogram_array_global_[feature_index].FromMemory(
output_buffer_.data() + smaller_buffer_read_start_pos_[feature_index]); output_buffer_.data() + smaller_buffer_read_start_pos_[feature_index]);
...@@ -399,22 +391,15 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons ...@@ -399,22 +391,15 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
smaller_leaf_splits_global_->sum_gradients(), smaller_leaf_splits_global_->sum_hessians(), smaller_leaf_splits_global_->sum_gradients(), smaller_leaf_splits_global_->sum_hessians(),
smaller_leaf_histogram_array_global_[feature_index].RawData()); smaller_leaf_histogram_array_global_[feature_index].RawData());
// find best threshold this->ComputeBestSplitForFeature(
smaller_leaf_histogram_array_global_[feature_index].FindBestThreshold( smaller_leaf_histogram_array_global_.get(), feature_index,
smaller_leaf_splits_global_->sum_gradients(), real_feature_index, smaller_node_used_features[feature_index],
smaller_leaf_splits_global_->sum_hessians(), GetGlobalDataCountInLeaf(smaller_leaf_splits_global_->leaf_index()),
GetGlobalDataCountInLeaf(smaller_leaf_splits_global_->LeafIndex()), smaller_leaf_splits_global_.get(), &smaller_bests_per_thread[tid]);
smaller_leaf_splits_global_->min_constraint(),
smaller_leaf_splits_global_->max_constraint(),
&smaller_split);
smaller_split.feature = real_feature_index;
if (smaller_split > smaller_bests_per_thread[tid] && smaller_node_used_features[feature_index]) {
smaller_bests_per_thread[tid] = smaller_split;
}
} }
if (larger_is_feature_aggregated_[feature_index]) { if (larger_is_feature_aggregated_[feature_index]) {
SplitInfo larger_split;
// restore from buffer // restore from buffer
larger_leaf_histogram_array_global_[feature_index].FromMemory(output_buffer_.data() + larger_buffer_read_start_pos_[feature_index]); larger_leaf_histogram_array_global_[feature_index].FromMemory(output_buffer_.data() + larger_buffer_read_start_pos_[feature_index]);
...@@ -422,47 +407,42 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons ...@@ -422,47 +407,42 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
larger_leaf_splits_global_->sum_gradients(), larger_leaf_splits_global_->sum_hessians(), larger_leaf_splits_global_->sum_gradients(), larger_leaf_splits_global_->sum_hessians(),
larger_leaf_histogram_array_global_[feature_index].RawData()); larger_leaf_histogram_array_global_[feature_index].RawData());
// find best threshold this->ComputeBestSplitForFeature(
larger_leaf_histogram_array_global_[feature_index].FindBestThreshold( larger_leaf_histogram_array_global_.get(), feature_index,
larger_leaf_splits_global_->sum_gradients(), real_feature_index,
larger_leaf_splits_global_->sum_hessians(), larger_node_used_features[feature_index],
GetGlobalDataCountInLeaf(larger_leaf_splits_global_->LeafIndex()), GetGlobalDataCountInLeaf(larger_leaf_splits_global_->leaf_index()),
larger_leaf_splits_global_->min_constraint(), larger_leaf_splits_global_.get(),
larger_leaf_splits_global_->max_constraint(), &larger_best_per_thread[tid]);
&larger_split);
larger_split.feature = real_feature_index;
if (larger_split > larger_best_per_thread[tid] && larger_node_used_features[feature_index]) {
larger_best_per_thread[tid] = larger_split;
}
} }
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
OMP_THROW_EX(); OMP_THROW_EX();
auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_bests_per_thread); auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_bests_per_thread);
int leaf = this->smaller_leaf_splits_->LeafIndex(); int leaf = this->smaller_leaf_splits_->leaf_index();
this->best_split_per_leaf_[leaf] = smaller_bests_per_thread[smaller_best_idx]; this->best_split_per_leaf_[leaf] = smaller_bests_per_thread[smaller_best_idx];
if (this->larger_leaf_splits_ != nullptr && this->larger_leaf_splits_->LeafIndex() >= 0) { if (this->larger_leaf_splits_ != nullptr && this->larger_leaf_splits_->leaf_index() >= 0) {
leaf = this->larger_leaf_splits_->LeafIndex(); leaf = this->larger_leaf_splits_->leaf_index();
auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_best_per_thread); auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_best_per_thread);
this->best_split_per_leaf_[leaf] = larger_best_per_thread[larger_best_idx]; this->best_split_per_leaf_[leaf] = larger_best_per_thread[larger_best_idx];
} }
// find local best // find local best
SplitInfo smaller_best_split, larger_best_split; SplitInfo smaller_best_split, larger_best_split;
smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()]; smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->leaf_index()];
// find local best split for larger leaf // find local best split for larger leaf
if (this->larger_leaf_splits_->LeafIndex() >= 0) { if (this->larger_leaf_splits_->leaf_index() >= 0) {
larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()]; larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->leaf_index()];
} }
// sync global best info // sync global best info
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->config_->max_cat_threshold); SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->config_->max_cat_threshold);
// copy back // copy back
this->best_split_per_leaf_[smaller_leaf_splits_global_->LeafIndex()] = smaller_best_split; this->best_split_per_leaf_[smaller_leaf_splits_global_->leaf_index()] = smaller_best_split;
if (larger_best_split.feature >= 0 && larger_leaf_splits_global_->LeafIndex() >= 0) { if (larger_best_split.feature >= 0 && larger_leaf_splits_global_->leaf_index() >= 0) {
this->best_split_per_leaf_[larger_leaf_splits_global_->LeafIndex()] = larger_best_split; this->best_split_per_leaf_[larger_leaf_splits_global_->leaf_index()] = larger_best_split;
} }
} }
...@@ -473,8 +453,6 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf, ...@@ -473,8 +453,6 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf,
// set the global number of data for leaves // set the global number of data for leaves
global_data_count_in_leaf_[*left_leaf] = best_split_info.left_count; global_data_count_in_leaf_[*left_leaf] = best_split_info.left_count;
global_data_count_in_leaf_[*right_leaf] = best_split_info.right_count; global_data_count_in_leaf_[*right_leaf] = best_split_info.right_count;
auto p_left = smaller_leaf_splits_global_.get();
auto p_right = larger_leaf_splits_global_.get();
// init the global sumup info // init the global sumup info
if (best_split_info.left_count < best_split_info.right_count) { if (best_split_info.left_count < best_split_info.right_count) {
smaller_leaf_splits_global_->Init(*left_leaf, this->data_partition_.get(), smaller_leaf_splits_global_->Init(*left_leaf, this->data_partition_.get(),
...@@ -490,22 +468,6 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf, ...@@ -490,22 +468,6 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf,
larger_leaf_splits_global_->Init(*left_leaf, this->data_partition_.get(), larger_leaf_splits_global_->Init(*left_leaf, this->data_partition_.get(),
best_split_info.left_sum_gradient, best_split_info.left_sum_gradient,
best_split_info.left_sum_hessian); best_split_info.left_sum_hessian);
p_left = larger_leaf_splits_global_.get();
p_right = smaller_leaf_splits_global_.get();
}
const int inner_feature_index = this->train_data_->InnerFeatureIndex(best_split_info.feature);
bool is_numerical_split = this->train_data_->FeatureBinMapper(inner_feature_index)->bin_type() == BinType::NumericalBin;
p_left->SetValueConstraint(best_split_info.min_constraint, best_split_info.max_constraint);
p_right->SetValueConstraint(best_split_info.min_constraint, best_split_info.max_constraint);
if (is_numerical_split) {
double mid = (best_split_info.left_output + best_split_info.right_output) / 2.0f;
if (best_split_info.monotone_type < 0) {
p_left->SetValueConstraint(mid, best_split_info.max_constraint);
p_right->SetValueConstraint(best_split_info.min_constraint, mid);
} else if (best_split_info.monotone_type > 0) {
p_left->SetValueConstraint(best_split_info.min_constraint, mid);
p_right->SetValueConstraint(mid, best_split_info.max_constraint);
}
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment