Unverified commit 3670e476, authored by CharlesAuguste, committed by GitHub

Refactoring monotone constraints (linked to #2305) (#2717)



* Move monotone constraints to the monotone_constraints files.

* Add checks for debug mode.

* Refactored FindBestSplitsFromHistograms.

* Add headers.

* fix

* Update data_parallel_tree_learner.cpp

* simplify ComputeBestSplitForFeature

* Fix min / max issue.

* Remove duplicated check.
Co-authored-by: Guolin Ke <guolin.ke@outlook.com>
parent d8a34df9
@@ -562,6 +562,9 @@ std::vector<double> GBDT::FeatureImportance(int num_iteration, int importance_ty
for (int iter = 0; iter < num_used_model; ++iter) {
for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
if (models_[iter]->split_gain(split_idx) > 0) {
#ifdef DEBUG
CHECK(models_[iter]->split_feature(split_idx) >= 0);
#endif
feature_importances[models_[iter]->split_feature(split_idx)] += 1.0;
}
}
@@ -570,6 +573,9 @@ std::vector<double> GBDT::FeatureImportance(int num_iteration, int importance_ty
for (int iter = 0; iter < num_used_model; ++iter) {
for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
if (models_[iter]->split_gain(split_idx) > 0) {
#ifdef DEBUG
CHECK(models_[iter]->split_feature(split_idx) >= 0);
#endif
feature_importances[models_[iter]->split_feature(split_idx)] += models_[iter]->split_gain(split_idx);
}
}
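Both hunks above add the same debug-only guard before indexing feature_importances. A minimal sketch of the pattern — assuming a CHECK macro in the spirit of LightGBM's (the macro body here is illustrative, not the project's actual definition) — showing that builds compiled without -DDEBUG pay nothing for the assertion:

#include <cstdio>
#include <cstdlib>

// Illustrative stand-in for a CHECK-style macro that aborts on failure.
#define CHECK(condition) \
  do { \
    if (!(condition)) { \
      std::fprintf(stderr, "Check failed: " #condition "\n"); \
      std::abort(); \
    } \
  } while (0)

// Hypothetical helper mirroring the loops above: validate the feature id
// only in debug builds, then accumulate importance.
void AccumulateImportance(int split_feature, double* feature_importances) {
#ifdef DEBUG
  CHECK(split_feature >= 0);  // compiled only with -DDEBUG
#endif
  feature_importances[split_feature] += 1.0;
}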
@@ -187,67 +187,55 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
this->train_data_->FixHistogram(feature_index,
this->smaller_leaf_splits_->sum_gradients(), this->smaller_leaf_splits_->sum_hessians(),
this->smaller_leaf_histogram_array_[feature_index].RawData());
SplitInfo smaller_split;
// find best threshold for smaller child
this->smaller_leaf_histogram_array_[feature_index].FindBestThreshold(
this->smaller_leaf_splits_->sum_gradients(),
this->smaller_leaf_splits_->sum_hessians(),
GetGlobalDataCountInLeaf(this->smaller_leaf_splits_->LeafIndex()),
this->smaller_leaf_splits_->min_constraint(),
this->smaller_leaf_splits_->max_constraint(),
&smaller_split);
smaller_split.feature = real_feature_index;
if (smaller_split > smaller_bests_per_thread[tid] && smaller_node_used_features[feature_index]) {
smaller_bests_per_thread[tid] = smaller_split;
}
this->ComputeBestSplitForFeature(
this->smaller_leaf_histogram_array_, feature_index, real_feature_index,
smaller_node_used_features[feature_index],
GetGlobalDataCountInLeaf(this->smaller_leaf_splits_->leaf_index()),
this->smaller_leaf_splits_.get(),
&smaller_bests_per_thread[tid]);
// only has root leaf
if (this->larger_leaf_splits_ == nullptr || this->larger_leaf_splits_->LeafIndex() < 0) continue;
if (this->larger_leaf_splits_ == nullptr || this->larger_leaf_splits_->leaf_index() < 0) continue;
// construct histograms for the larger leaf; we init the larger leaf as the parent, so we can just subtract the smaller leaf's histograms
this->larger_leaf_histogram_array_[feature_index].Subtract(
this->smaller_leaf_histogram_array_[feature_index]);
SplitInfo larger_split;
// find best threshold for larger child
this->larger_leaf_histogram_array_[feature_index].FindBestThreshold(
this->larger_leaf_splits_->sum_gradients(),
this->larger_leaf_splits_->sum_hessians(),
GetGlobalDataCountInLeaf(this->larger_leaf_splits_->LeafIndex()),
this->larger_leaf_splits_->min_constraint(),
this->larger_leaf_splits_->max_constraint(),
&larger_split);
larger_split.feature = real_feature_index;
if (larger_split > larger_bests_per_thread[tid] && larger_node_used_features[feature_index]) {
larger_bests_per_thread[tid] = larger_split;
}
this->ComputeBestSplitForFeature(
this->larger_leaf_histogram_array_, feature_index, real_feature_index,
larger_node_used_features[feature_index],
GetGlobalDataCountInLeaf(this->larger_leaf_splits_->leaf_index()),
this->larger_leaf_splits_.get(),
&larger_bests_per_thread[tid]);
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_bests_per_thread);
int leaf = this->smaller_leaf_splits_->LeafIndex();
int leaf = this->smaller_leaf_splits_->leaf_index();
this->best_split_per_leaf_[leaf] = smaller_bests_per_thread[smaller_best_idx];
if (this->larger_leaf_splits_ != nullptr && this->larger_leaf_splits_->LeafIndex() >= 0) {
leaf = this->larger_leaf_splits_->LeafIndex();
if (this->larger_leaf_splits_ != nullptr && this->larger_leaf_splits_->leaf_index() >= 0) {
leaf = this->larger_leaf_splits_->leaf_index();
auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_bests_per_thread);
this->best_split_per_leaf_[leaf] = larger_bests_per_thread[larger_best_idx];
}
SplitInfo smaller_best_split, larger_best_split;
smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()];
smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->leaf_index()];
// find local best split for larger leaf
if (this->larger_leaf_splits_->LeafIndex() >= 0) {
larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()];
if (this->larger_leaf_splits_->leaf_index() >= 0) {
larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->leaf_index()];
}
// sync global best info
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->config_->max_cat_threshold);
// set best split
this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split;
if (this->larger_leaf_splits_->LeafIndex() >= 0) {
this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()] = larger_best_split;
this->best_split_per_leaf_[this->smaller_leaf_splits_->leaf_index()] = smaller_best_split;
if (this->larger_leaf_splits_->leaf_index() >= 0) {
this->best_split_per_leaf_[this->larger_leaf_splits_->leaf_index()] = larger_best_split;
}
}
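The rewritten loop above keeps one best candidate per OpenMP thread (smaller_bests_per_thread / larger_bests_per_thread) and reduces them afterwards with ArrayArgs<SplitInfo>::ArgMax. A minimal sketch of the same pattern with plain doubles — the names here are illustrative, not LightGBM API:

#include <omp.h>
#include <algorithm>
#include <limits>
#include <vector>

// Each thread writes only to its own slot inside the parallel loop, so no
// locking is needed in the hot path; one serial pass then picks the global
// winner (the role ArgMax plays above).
double BestGain(const std::vector<double>& gains) {
  std::vector<double> best_per_thread(omp_get_max_threads(),
                                      -std::numeric_limits<double>::infinity());
#pragma omp parallel for schedule(static)
  for (int i = 0; i < static_cast<int>(gains.size()); ++i) {
    const int tid = omp_get_thread_num();
    best_per_thread[tid] = std::max(best_per_thread[tid], gains[i]);
  }
  return *std::max_element(best_per_thread.begin(), best_per_thread.end());
}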
@@ -17,6 +17,7 @@
#include <vector>
#include "split_info.hpp"
#include "monotone_constraints.hpp"
namespace LightGBM {
@@ -58,11 +59,11 @@ class FeatureHistogram {
meta_ = meta;
data_ = data;
if (meta_->bin_type == BinType::NumericalBin) {
find_best_threshold_fun_ = std::bind(&FeatureHistogram::FindBestThresholdNumerical, this, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6);
find_best_threshold_fun_ = std::bind(&FeatureHistogram::FindBestThresholdNumerical, this, std::placeholders::_1
, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5);
} else {
find_best_threshold_fun_ = std::bind(&FeatureHistogram::FindBestThresholdCategorical, this, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6);
find_best_threshold_fun_ = std::bind(&FeatureHistogram::FindBestThresholdCategorical, this, std::placeholders::_1
, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5);
}
rand_ = Random(meta_->config->extra_seed);
}
@@ -81,15 +82,15 @@ class FeatureHistogram {
}
void FindBestThreshold(double sum_gradient, double sum_hessian, data_size_t num_data,
double min_constraint, double max_constraint, SplitInfo* output) {
const ConstraintEntry& constraints, SplitInfo* output) {
output->default_left = true;
output->gain = kMinScore;
find_best_threshold_fun_(sum_gradient, sum_hessian + 2 * kEpsilon, num_data, min_constraint, max_constraint, output);
find_best_threshold_fun_(sum_gradient, sum_hessian + 2 * kEpsilon, num_data, constraints, output);
output->gain *= meta_->penalty;
}
void FindBestThresholdNumerical(double sum_gradient, double sum_hessian, data_size_t num_data,
double min_constraint, double max_constraint, SplitInfo* output) {
const ConstraintEntry& constraints, SplitInfo* output) {
is_splittable_ = false;
double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian,
meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step);
@@ -102,26 +103,26 @@ class FeatureHistogram {
if (meta_->num_bin > 2 && meta_->missing_type != MissingType::None) {
if (meta_->missing_type == MissingType::Zero) {
if (is_rand) {
FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, true, false, rand_threshold);
FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, 1, true, false, rand_threshold);
FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, constraints, min_gain_shift, output, -1, true, false, rand_threshold);
FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, constraints, min_gain_shift, output, 1, true, false, rand_threshold);
} else {
FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, true, false, rand_threshold);
FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, 1, true, false, rand_threshold);
FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, constraints, min_gain_shift, output, -1, true, false, rand_threshold);
FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, constraints, min_gain_shift, output, 1, true, false, rand_threshold);
}
} else {
if (is_rand) {
FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, false, true, rand_threshold);
FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, 1, false, true, rand_threshold);
FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, constraints, min_gain_shift, output, -1, false, true, rand_threshold);
FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, constraints, min_gain_shift, output, 1, false, true, rand_threshold);
} else {
FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, false, true, rand_threshold);
FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, 1, false, true, rand_threshold);
FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, constraints, min_gain_shift, output, -1, false, true, rand_threshold);
FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, constraints, min_gain_shift, output, 1, false, true, rand_threshold);
}
}
} else {
if (is_rand) {
FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, false, false, rand_threshold);
FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, constraints, min_gain_shift, output, -1, false, false, rand_threshold);
} else {
FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, false, false, rand_threshold);
FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, constraints, min_gain_shift, output, -1, false, false, rand_threshold);
}
// fix the direction error when there are only 2 bins
if (meta_->missing_type == MissingType::NaN) {
@@ -130,12 +131,10 @@ class FeatureHistogram {
}
output->gain -= min_gain_shift;
output->monotone_type = meta_->monotone_type;
output->min_constraint = min_constraint;
output->max_constraint = max_constraint;
}
void FindBestThresholdCategorical(double sum_gradient, double sum_hessian, data_size_t num_data,
double min_constraint, double max_constraint, SplitInfo* output) {
const ConstraintEntry& constraints, SplitInfo* output) {
output->default_left = false;
double best_gain = kMinScore;
data_size_t best_left_count = 0;
@@ -173,7 +172,7 @@ class FeatureHistogram {
double sum_other_gradient = sum_gradient - grad;
// current split gain
double current_gain = GetSplitGains(sum_other_gradient, sum_other_hessian, grad, hess + kEpsilon,
meta_->config->lambda_l1, l2, meta_->config->max_delta_step, min_constraint, max_constraint, 0);
meta_->config->lambda_l1, l2, meta_->config->max_delta_step, constraints, 0);
// gain with split is worse than without split
if (current_gain <= min_gain_shift) continue;
@@ -253,7 +252,7 @@ class FeatureHistogram {
double sum_right_gradient = sum_gradient - sum_left_gradient;
if (!meta_->config->extra_trees || i == rand_threshold) {
double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian,
meta_->config->lambda_l1, l2, meta_->config->max_delta_step, min_constraint, max_constraint, 0);
meta_->config->lambda_l1, l2, meta_->config->max_delta_step, constraints, 0);
if (current_gain <= min_gain_shift) continue;
is_splittable_ = true;
if (current_gain > best_gain) {
@@ -271,13 +270,13 @@ class FeatureHistogram {
if (is_splittable_) {
output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian,
meta_->config->lambda_l1, l2, meta_->config->max_delta_step, min_constraint, max_constraint);
meta_->config->lambda_l1, l2, meta_->config->max_delta_step, constraints);
output->left_count = best_left_count;
output->left_sum_gradient = best_sum_left_gradient;
output->left_sum_hessian = best_sum_left_hessian - kEpsilon;
output->right_output = CalculateSplittedLeafOutput(
sum_gradient - best_sum_left_gradient, sum_hessian - best_sum_left_hessian,
meta_->config->lambda_l1, l2, meta_->config->max_delta_step, min_constraint, max_constraint);
meta_->config->lambda_l1, l2, meta_->config->max_delta_step, constraints);
output->right_count = num_data - best_left_count;
output->right_sum_gradient = sum_gradient - best_sum_left_gradient;
output->right_sum_hessian = sum_hessian - best_sum_left_hessian - kEpsilon;
@@ -301,8 +300,6 @@ class FeatureHistogram {
}
}
output->monotone_type = 0;
output->min_constraint = min_constraint;
output->max_constraint = max_constraint;
}
}
@@ -481,9 +478,9 @@ class FeatureHistogram {
static double GetSplitGains(double sum_left_gradients, double sum_left_hessians,
double sum_right_gradients, double sum_right_hessians,
double l1, double l2, double max_delta_step,
double min_constraint, double max_constraint, int8_t monotone_constraint) {
double left_output = CalculateSplittedLeafOutput(sum_left_gradients, sum_left_hessians, l1, l2, max_delta_step, min_constraint, max_constraint);
double right_output = CalculateSplittedLeafOutput(sum_right_gradients, sum_right_hessians, l1, l2, max_delta_step, min_constraint, max_constraint);
const ConstraintEntry& constraints, int8_t monotone_constraint) {
double left_output = CalculateSplittedLeafOutput(sum_left_gradients, sum_left_hessians, l1, l2, max_delta_step, constraints);
double right_output = CalculateSplittedLeafOutput(sum_right_gradients, sum_right_hessians, l1, l2, max_delta_step, constraints);
if (((monotone_constraint > 0) && (left_output > right_output)) ||
((monotone_constraint < 0) && (left_output < right_output))) {
return 0;
@@ -498,13 +495,14 @@ class FeatureHistogram {
* \param sum_hessians
* \return leaf output
*/
static double CalculateSplittedLeafOutput(double sum_gradients, double sum_hessians, double l1, double l2, double max_delta_step,
double min_constraint, double max_constraint) {
static double CalculateSplittedLeafOutput(double sum_gradients, double sum_hessians,
double l1, double l2, double max_delta_step, const ConstraintEntry& constraints) {
double ret = CalculateSplittedLeafOutput(sum_gradients, sum_hessians, l1, l2, max_delta_step);
if (ret < min_constraint) {
ret = min_constraint;
} else if (ret > max_constraint) {
ret = max_constraint;
if (ret < constraints.min) {
ret = constraints.min;
} else if (ret > constraints.max) {
ret = constraints.max;
}
return ret;
}
@@ -526,7 +524,7 @@ class FeatureHistogram {
}
template<bool is_rand>
void FindBestThresholdSequence(double sum_gradient, double sum_hessian, data_size_t num_data, double min_constraint, double max_constraint,
void FindBestThresholdSequence(double sum_gradient, double sum_hessian, data_size_t num_data, const ConstraintEntry& constraints,
double min_gain_shift, SplitInfo* output, int dir, bool skip_default_bin, bool use_na_as_missing, int rand_threshold) {
const int8_t offset = meta_->offset;
@@ -571,7 +569,7 @@ class FeatureHistogram {
// current split gain
double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian,
meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step,
min_constraint, max_constraint, meta_->monotone_type);
constraints, meta_->monotone_type);
// gain with split is worse than without split
if (current_gain <= min_gain_shift) continue;
@@ -635,7 +633,7 @@ class FeatureHistogram {
// current split gain
double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian,
meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step,
min_constraint, max_constraint, meta_->monotone_type);
constraints, meta_->monotone_type);
// gain with split is worse than without split
if (current_gain <= min_gain_shift) continue;
@@ -656,16 +654,17 @@ class FeatureHistogram {
if (is_splittable_ && best_gain > output->gain) {
// update split information
output->threshold = best_threshold;
output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian,
meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step,
min_constraint, max_constraint);
output->left_output = CalculateSplittedLeafOutput(
best_sum_left_gradient, best_sum_left_hessian,
meta_->config->lambda_l1, meta_->config->lambda_l2,
meta_->config->max_delta_step, constraints);
output->left_count = best_left_count;
output->left_sum_gradient = best_sum_left_gradient;
output->left_sum_hessian = best_sum_left_hessian - kEpsilon;
output->right_output = CalculateSplittedLeafOutput(
sum_gradient - best_sum_left_gradient, sum_hessian - best_sum_left_hessian,
meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step,
min_constraint, max_constraint);
constraints);
output->right_count = num_data - best_left_count;
output->right_sum_gradient = sum_gradient - best_sum_left_gradient;
output->right_sum_hessian = sum_hessian - best_sum_left_hessian - kEpsilon;
@@ -681,7 +680,9 @@ class FeatureHistogram {
/*! \brief random number generator for extremely randomized trees */
Random rand_;
std::function<void(double, double, data_size_t, double, double, SplitInfo*)> find_best_threshold_fun_;
std::function<void(double, double, data_size_t, const ConstraintEntry&, SplitInfo*)> find_best_threshold_fun_;
};
class HistogramPool {
public:
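Two pieces of the histogram code above now consume the ConstraintEntry directly: CalculateSplittedLeafOutput clamps the leaf value into [constraints.min, constraints.max], and GetSplitGains awards zero gain to any split whose child outputs violate the monotone direction. A simplified sketch of both ideas, with regularization dropped and the standard second-order output -G/H (an illustration, not the committed code):

#include <algorithm>
#include <cstdint>

// Clamp the unconstrained leaf output into the leaf's allowed interval,
// as CalculateSplittedLeafOutput does with constraints.min / constraints.max.
double ClampedLeafOutput(double sum_gradients, double sum_hessians,
                         double cmin, double cmax) {
  const double raw = -sum_gradients / sum_hessians;
  return std::min(std::max(raw, cmin), cmax);
}

// A split that orders the child outputs against the monotone direction is
// vetoed with zero gain, mirroring the early return in GetSplitGains.
bool ViolatesMonotonicity(double left_output, double right_output,
                          int8_t monotone_constraint) {
  return (monotone_constraint > 0 && left_output > right_output) ||
         (monotone_constraint < 0 && left_output < right_output);
}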
@@ -56,17 +56,17 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(con
TREELEARNER_T::FindBestSplitsFromHistograms(is_feature_used, use_subtract);
SplitInfo smaller_best_split, larger_best_split;
// get best split at smaller leaf
smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()];
smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->leaf_index()];
// find local best split for larger leaf
if (this->larger_leaf_splits_->LeafIndex() >= 0) {
larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()];
if (this->larger_leaf_splits_->leaf_index() >= 0) {
larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->leaf_index()];
}
// sync global best info
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->config_->max_cat_threshold);
// update best split
this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split;
if (this->larger_leaf_splits_->LeafIndex() >= 0) {
this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()] = larger_best_split;
this->best_split_per_leaf_[this->smaller_leaf_splits_->leaf_index()] = smaller_best_split;
if (this->larger_leaf_splits_->leaf_index() >= 0) {
this->best_split_per_leaf_[this->larger_leaf_splits_->leaf_index()] = larger_best_split;
}
}
@@ -1089,14 +1089,8 @@ void GPUTreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right
Log::Fatal("Bug in GPU histogram! split %d: %d, smaller_leaf: %d, larger_leaf: %d\n", best_split_info.left_count, best_split_info.right_count, smaller_leaf_splits_->num_data_in_leaf(), larger_leaf_splits_->num_data_in_leaf());
}
} else {
double smaller_min = smaller_leaf_splits_->min_constraint();
double smaller_max = smaller_leaf_splits_->max_constraint();
double larger_min = larger_leaf_splits_->min_constraint();
double larger_max = larger_leaf_splits_->max_constraint();
smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(), best_split_info.right_sum_gradient, best_split_info.right_sum_hessian);
larger_leaf_splits_->Init(*left_leaf, data_partition_.get(), best_split_info.left_sum_gradient, best_split_info.left_sum_hessian);
smaller_leaf_splits_->SetValueConstraint(smaller_min, smaller_max);
larger_leaf_splits_->SetValueConstraint(larger_min, larger_max);
if ((best_split_info.left_count != larger_leaf_splits_->num_data_in_leaf()) ||
(best_split_info.right_count != smaller_leaf_splits_->num_data_in_leaf())) {
Log::Fatal("Bug in GPU histogram! split %d: %d, smaller_leaf: %d, larger_leaf: %d\n", best_split_info.left_count, best_split_info.right_count, smaller_leaf_splits_->num_data_in_leaf(), larger_leaf_splits_->num_data_in_leaf());
With the constraint state moved out of LeafSplits, the GPU learner no longer needs to save and restore the min/max bounds around re-initializing the two leaf-split objects.
@@ -43,16 +43,8 @@ class LeafSplits {
data_indices_ = data_partition->GetIndexOnLeaf(leaf, &num_data_in_leaf_);
sum_gradients_ = sum_gradients;
sum_hessians_ = sum_hessians;
min_val_ = -std::numeric_limits<double>::max();
max_val_ = std::numeric_limits<double>::max();
}
void SetValueConstraint(double min, double max) {
min_val_ = min;
max_val_ = max;
}
/*!
* \brief Init splits on current leaf, it will traverse all data to sum up the results
* \param gradients
@@ -71,8 +63,6 @@
}
sum_gradients_ = tmp_sum_gradients;
sum_hessians_ = tmp_sum_hessians;
min_val_ = -std::numeric_limits<double>::max();
max_val_ = std::numeric_limits<double>::max();
}
/*!
@@ -95,8 +85,6 @@
}
sum_gradients_ = tmp_sum_gradients;
sum_hessians_ = tmp_sum_hessians;
min_val_ = -std::numeric_limits<double>::max();
max_val_ = std::numeric_limits<double>::max();
}
@@ -109,8 +97,6 @@
leaf_index_ = 0;
sum_gradients_ = sum_gradients;
sum_hessians_ = sum_hessians;
min_val_ = -std::numeric_limits<double>::max();
max_val_ = std::numeric_limits<double>::max();
}
/*!
@@ -120,13 +106,11 @@
leaf_index_ = -1;
data_indices_ = nullptr;
num_data_in_leaf_ = 0;
min_val_ = -std::numeric_limits<double>::max();
max_val_ = std::numeric_limits<double>::max();
}
/*! \brief Get current leaf index */
int LeafIndex() const { return leaf_index_; }
int leaf_index() const { return leaf_index_; }
/*! \brief Get number of data in current leaf */
data_size_t num_data_in_leaf() const { return num_data_in_leaf_; }
@@ -137,10 +121,6 @@
/*! \brief Get sum of hessians of current leaf */
double sum_hessians() const { return sum_hessians_; }
double max_constraint() const { return max_val_; }
double min_constraint() const { return min_val_; }
/*! \brief Get indices of data of current leaf */
const data_size_t* data_indices() const { return data_indices_; }
@@ -158,8 +138,6 @@
double sum_hessians_;
/*! \brief indices of data of current leaf */
const data_size_t* data_indices_;
double min_val_;
double max_val_;
};
} // namespace LightGBM
New file: monotone_constraints.hpp — the header included by feature_histogram.hpp above and serial_tree_learner.h below; it centralizes the per-leaf min/max state just removed from LeafSplits.
#ifndef LIGHTGBM_TREELEARNER_MONOTONE_CONSTRAINTS_H_
#define LIGHTGBM_TREELEARNER_MONOTONE_CONSTRAINTS_H_
#include <algorithm>
#include <vector>
#include <cstdint>
#include <limits>
namespace LightGBM {
struct ConstraintEntry {
double min = -std::numeric_limits<double>::max();
double max = std::numeric_limits<double>::max();
ConstraintEntry() {}
void Reset() {
min = -std::numeric_limits<double>::max();
max = std::numeric_limits<double>::max();
}
void UpdateMin(double new_min) { min = std::max(new_min, min); }
void UpdateMax(double new_max) { max = std::min(new_max, max); }
};
template <typename ConstraintEntry>
class LeafConstraints {
public:
LeafConstraints(int num_leaves) : num_leaves_(num_leaves) {
entries_.resize(num_leaves_);
}
void Reset() {
for (auto& entry : entries_) {
entry.Reset();
}
}
void UpdateConstraints(bool is_numerical_split, int leaf, int new_leaf,
int8_t monotone_type, double right_output,
double left_output) {
entries_[new_leaf] = entries_[leaf];
if (is_numerical_split) {
double mid = (left_output + right_output) / 2.0f;
if (monotone_type < 0) {
entries_[leaf].UpdateMin(mid);
entries_[new_leaf].UpdateMax(mid);
} else if (monotone_type > 0) {
entries_[leaf].UpdateMax(mid);
entries_[new_leaf].UpdateMin(mid);
}
}
}
const ConstraintEntry& Get(int leaf_idx) const { return entries_[leaf_idx]; }
private:
int num_leaves_;
std::vector<ConstraintEntry> entries_;
};
} // namespace LightGBM
#endif  // LIGHTGBM_TREELEARNER_MONOTONE_CONSTRAINTS_H_
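A small usage sketch of the new container (the leaf indices, outputs, and constraint direction below are made-up values). On a numerical split under an increasing constraint (monotone_type = +1), the parent's interval is cut at the midpoint of the two child outputs: the left child's future outputs are capped at the midpoint, the right child's are floored there.

#include <cassert>
#include "monotone_constraints.hpp"

int main() {
  LightGBM::LeafConstraints<LightGBM::ConstraintEntry> constraints(4);
  constraints.Reset();  // every leaf starts unbounded

  // Leaf 0 splits into a left child (kept as leaf 0) and a right child
  // (new leaf 1) on an increasing feature; child outputs are 0.25 (left)
  // and 0.75 (right), so both children inherit the parent's interval cut
  // at mid = 0.5.
  constraints.UpdateConstraints(/*is_numerical_split=*/true,
                                /*leaf=*/0, /*new_leaf=*/1,
                                /*monotone_type=*/1,
                                /*right_output=*/0.75,
                                /*left_output=*/0.25);

  assert(constraints.Get(0).max == 0.5);  // left child capped from above
  assert(constraints.Get(1).min == 0.5);  // right child floored from below
  return 0;
}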
@@ -53,6 +53,7 @@ void SerialTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian
// push split information for all leaves
best_split_per_leaf_.resize(config_->num_leaves);
constraints_.reset(new LeafConstraints<ConstraintEntry>(config_->num_leaves));
// initialize splits for leaf
smaller_leaf_splits_.reset(new LeafSplits(train_data_->num_data()));
@@ -291,6 +292,8 @@ void SerialTreeLearner::BeforeTrain() {
// initialize data partition
data_partition_->Init();
constraints_->Reset();
// reset the splits for leaves
for (int i = 0; i < config_->num_leaves; ++i) {
best_split_per_leaf_[i].Reset();
@@ -408,27 +411,19 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>&
OMP_LOOP_EX_BEGIN();
if (!is_feature_used[feature_index]) { continue; }
const int tid = omp_get_thread_num();
SplitInfo smaller_split;
train_data_->FixHistogram(feature_index,
smaller_leaf_splits_->sum_gradients(), smaller_leaf_splits_->sum_hessians(),
smaller_leaf_histogram_array_[feature_index].RawData());
int real_fidx = train_data_->RealFeatureIndex(feature_index);
smaller_leaf_histogram_array_[feature_index].FindBestThreshold(
smaller_leaf_splits_->sum_gradients(),
smaller_leaf_splits_->sum_hessians(),
smaller_leaf_splits_->num_data_in_leaf(),
smaller_leaf_splits_->min_constraint(),
smaller_leaf_splits_->max_constraint(),
&smaller_split);
smaller_split.feature = real_fidx;
if (cegb_ != nullptr) {
smaller_split.gain -= cegb_->DetlaGain(feature_index, real_fidx, smaller_leaf_splits_->LeafIndex(), smaller_leaf_splits_->num_data_in_leaf(), smaller_split);
}
if (smaller_split > smaller_best[tid] && smaller_node_used_features[feature_index]) {
smaller_best[tid] = smaller_split;
}
ComputeBestSplitForFeature(smaller_leaf_histogram_array_, feature_index,
real_fidx,
smaller_node_used_features[feature_index],
smaller_leaf_splits_->num_data_in_leaf(),
smaller_leaf_splits_.get(), &smaller_best[tid]);
// only has root leaf
if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->LeafIndex() < 0) { continue; }
if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->leaf_index() < 0) { continue; }
if (use_subtract) {
larger_leaf_histogram_array_[feature_index].Subtract(smaller_leaf_histogram_array_[feature_index]);
@@ -436,31 +431,23 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>&
train_data_->FixHistogram(feature_index, larger_leaf_splits_->sum_gradients(), larger_leaf_splits_->sum_hessians(),
larger_leaf_histogram_array_[feature_index].RawData());
}
SplitInfo larger_split;
// find best threshold for larger child
larger_leaf_histogram_array_[feature_index].FindBestThreshold(
larger_leaf_splits_->sum_gradients(),
larger_leaf_splits_->sum_hessians(),
larger_leaf_splits_->num_data_in_leaf(),
larger_leaf_splits_->min_constraint(),
larger_leaf_splits_->max_constraint(),
&larger_split);
larger_split.feature = real_fidx;
if (cegb_ != nullptr) {
larger_split.gain -= cegb_->DetlaGain(feature_index, real_fidx, larger_leaf_splits_->LeafIndex(), larger_leaf_splits_->num_data_in_leaf(), larger_split);
}
if (larger_split > larger_best[tid] && larger_node_used_features[feature_index]) {
larger_best[tid] = larger_split;
}
ComputeBestSplitForFeature(larger_leaf_histogram_array_, feature_index,
real_fidx,
larger_node_used_features[feature_index],
larger_leaf_splits_->num_data_in_leaf(),
larger_leaf_splits_.get(),
&larger_best[tid]);
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_best);
int leaf = smaller_leaf_splits_->LeafIndex();
int leaf = smaller_leaf_splits_->leaf_index();
best_split_per_leaf_[leaf] = smaller_best[smaller_best_idx];
if (larger_leaf_splits_ != nullptr && larger_leaf_splits_->LeafIndex() >= 0) {
leaf = larger_leaf_splits_->LeafIndex();
if (larger_leaf_splits_ != nullptr && larger_leaf_splits_->leaf_index() >= 0) {
leaf = larger_leaf_splits_->leaf_index();
auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_best);
best_split_per_leaf_[leaf] = larger_best[larger_best_idx];
}
@@ -692,10 +679,11 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
static_cast<float>(best_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type());
}
#ifdef DEBUG
CHECK(*right_leaf == next_leaf_id);
#endif
auto p_left = smaller_leaf_splits_.get();
auto p_right = larger_leaf_splits_.get();
// init the leaves that will be used in the next iteration
if (best_split_info.left_count < best_split_info.right_count) {
CHECK(best_split_info.left_count > 0);
@@ -705,21 +693,11 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
CHECK(best_split_info.right_count > 0);
smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(), best_split_info.right_sum_gradient, best_split_info.right_sum_hessian);
larger_leaf_splits_->Init(*left_leaf, data_partition_.get(), best_split_info.left_sum_gradient, best_split_info.left_sum_hessian);
p_right = smaller_leaf_splits_.get();
p_left = larger_leaf_splits_.get();
}
p_left->SetValueConstraint(best_split_info.min_constraint, best_split_info.max_constraint);
p_right->SetValueConstraint(best_split_info.min_constraint, best_split_info.max_constraint);
if (is_numerical_split) {
double mid = (best_split_info.left_output + best_split_info.right_output) / 2.0f;
if (best_split_info.monotone_type < 0) {
p_left->SetValueConstraint(mid, best_split_info.max_constraint);
p_right->SetValueConstraint(best_split_info.min_constraint, mid);
} else if (best_split_info.monotone_type > 0) {
p_left->SetValueConstraint(best_split_info.min_constraint, mid);
p_right->SetValueConstraint(mid, best_split_info.max_constraint);
}
}
constraints_->UpdateConstraints(
is_numerical_split, *left_leaf, *right_leaf,
best_split_info.monotone_type, best_split_info.right_output,
best_split_info.left_output);
}
@@ -763,4 +741,26 @@ void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj
}
}
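// Shared helper extracted by this refactor: find the best threshold for one
// feature on a leaf under that leaf's constraint entry, apply the optional
// CEGB gain penalty, and keep the candidate only if it beats *best_split.
// The serial, data-parallel and voting-parallel learners all route through it.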
void SerialTreeLearner::ComputeBestSplitForFeature(
FeatureHistogram* histogram_array_, int feature_index, int real_fidx,
bool is_feature_used, int num_data, const LeafSplits* leaf_splits,
SplitInfo* best_split) {
if (!is_feature_used) {
return;
}
SplitInfo new_split;
histogram_array_[feature_index].FindBestThreshold(
leaf_splits->sum_gradients(), leaf_splits->sum_hessians(), num_data,
constraints_->Get(leaf_splits->leaf_index()), &new_split);
new_split.feature = real_fidx;
if (cegb_ != nullptr) {
new_split.gain -=
cegb_->DetlaGain(feature_index, real_fidx, leaf_splits->leaf_index(),
num_data, new_split);
}
if (new_split > *best_split) {
*best_split = new_split;
}
}
} // namespace LightGBM
@@ -22,6 +22,7 @@
#include "feature_histogram.hpp"
#include "leaf_splits.hpp"
#include "split_info.hpp"
#include "monotone_constraints.hpp"
#ifdef USE_GPU
// Use 4KBytes aligned allocator for ordered gradients and ordered hessians when GPU is enabled.
@@ -82,6 +83,12 @@ class SerialTreeLearner: public TreeLearner {
bool IsHistColWise() const override { return is_hist_colwise_; }
protected:
void ComputeBestSplitForFeature(FeatureHistogram* histogram_array_,
int feature_index, int real_fidx,
bool is_feature_used, int num_data,
const LeafSplits* leaf_splits,
SplitInfo* best_split);
void GetMultiValBin(const Dataset* dataset, bool is_first_time);
virtual std::vector<int8_t> GetUsedFeatures(bool is_tree_level);
@@ -151,6 +158,8 @@ class SerialTreeLearner: public TreeLearner {
std::vector<SplitInfo> best_split_per_leaf_;
/*! \brief store best split per feature for all leaves */
std::vector<SplitInfo> splits_per_leaf_;
// Stores minimum and maximum constraints for each leaf
std::unique_ptr<LeafConstraints<ConstraintEntry>> constraints_;
/*! \brief stores best thresholds for all feature for smaller leaf */
std::unique_ptr<LeafSplits> smaller_leaf_splits_;
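Taken together, the serial tree learner hunks give the constraint store a simple per-tree lifecycle — summarized below from the hunks above, with method bodies elided:

// Init()        -> constraints_.reset(new LeafConstraints<ConstraintEntry>(config_->num_leaves));
// BeforeTrain() -> constraints_->Reset();                // every leaf unbounded again
// FindBestSplitsFromHistograms()
//               -> constraints_->Get(leaf_index) feeds FindBestThreshold
// Split()       -> constraints_->UpdateConstraints(...); // tighten the two children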
@@ -48,8 +48,6 @@ struct SplitInfo {
/*! \brief True if default split is left */
bool default_left = true;
int8_t monotone_type = 0;
double min_constraint = -std::numeric_limits<double>::max();
double max_constraint = std::numeric_limits<double>::max();
inline static int Size(int max_cat_threshold) {
return 2 * sizeof(int) + sizeof(uint32_t) + sizeof(bool) + sizeof(double) * 9 + sizeof(data_size_t) * 2 + max_cat_threshold * sizeof(uint32_t) + sizeof(int8_t);
}
@@ -81,10 +79,6 @@
buffer += sizeof(default_left);
std::memcpy(buffer, &monotone_type, sizeof(monotone_type));
buffer += sizeof(monotone_type);
std::memcpy(buffer, &min_constraint, sizeof(min_constraint));
buffer += sizeof(min_constraint);
std::memcpy(buffer, &max_constraint, sizeof(max_constraint));
buffer += sizeof(max_constraint);
std::memcpy(buffer, &num_cat_threshold, sizeof(num_cat_threshold));
buffer += sizeof(num_cat_threshold);
std::memcpy(buffer, cat_threshold.data(), sizeof(uint32_t) * num_cat_threshold);
@@ -117,10 +111,6 @@
buffer += sizeof(default_left);
std::memcpy(&monotone_type, buffer, sizeof(monotone_type));
buffer += sizeof(monotone_type);
std::memcpy(&min_constraint, buffer, sizeof(min_constraint));
buffer += sizeof(min_constraint);
std::memcpy(&max_constraint, buffer, sizeof(max_constraint));
buffer += sizeof(max_constraint);
std::memcpy(&num_cat_threshold, buffer, sizeof(num_cat_threshold));
buffer += sizeof(num_cat_threshold);
cat_threshold.resize(num_cat_threshold);
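With min_constraint and max_constraint gone from SplitInfo, Size(), CopyTo() and CopyFrom() all shrink together; the three must agree on the byte layout exactly, since the parallel learners ship these buffers through SyncUpGlobalBestSplit. A toy sketch of the same append-style layout (field set abbreviated — not the real SplitInfo):

#include <cstdint>
#include <cstring>
#include <vector>

// Toy struct mirroring SplitInfo's serialization pattern: a fixed header
// copied field by field, then a variable-length categorical threshold
// array whose maximum length is baked into the agreed buffer size.
struct ToySplit {
  int feature = -1;
  double gain = 0.0;
  int8_t monotone_type = 0;
  std::vector<uint32_t> cat_threshold;

  // Upper bound on the serialized size; must track every memcpy below.
  static int Size(int max_cat_threshold) {
    return 2 * sizeof(int) + sizeof(double) + sizeof(int8_t) +
           max_cat_threshold * sizeof(uint32_t);
  }

  void CopyTo(char* buffer) const {
    std::memcpy(buffer, &feature, sizeof(feature));
    buffer += sizeof(feature);
    std::memcpy(buffer, &gain, sizeof(gain));
    buffer += sizeof(gain);
    std::memcpy(buffer, &monotone_type, sizeof(monotone_type));
    buffer += sizeof(monotone_type);
    const int num_cat = static_cast<int>(cat_threshold.size());
    std::memcpy(buffer, &num_cat, sizeof(num_cat));
    buffer += sizeof(num_cat);
    std::memcpy(buffer, cat_threshold.data(), sizeof(uint32_t) * num_cat);
  }
};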
@@ -292,16 +292,13 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
this->smaller_leaf_splits_->sum_gradients(), this->smaller_leaf_splits_->sum_hessians(),
this->smaller_leaf_histogram_array_[feature_index].RawData());
this->smaller_leaf_histogram_array_[feature_index].FindBestThreshold(
this->smaller_leaf_splits_->sum_gradients(),
this->smaller_leaf_splits_->sum_hessians(),
this->smaller_leaf_splits_->num_data_in_leaf(),
this->smaller_leaf_splits_->min_constraint(),
this->smaller_leaf_splits_->max_constraint(),
&smaller_bestsplit_per_features[feature_index]);
smaller_bestsplit_per_features[feature_index].feature = real_feature_index;
this->ComputeBestSplitForFeature(
this->smaller_leaf_histogram_array_, feature_index, real_feature_index,
true, this->smaller_leaf_splits_->num_data_in_leaf(),
this->smaller_leaf_splits_.get(),
&smaller_bestsplit_per_features[feature_index]);
// only has root leaf
if (this->larger_leaf_splits_ == nullptr || this->larger_leaf_splits_->LeafIndex() < 0) { continue; }
if (this->larger_leaf_splits_ == nullptr || this->larger_leaf_splits_->leaf_index() < 0) { continue; }
if (use_subtract) {
this->larger_leaf_histogram_array_[feature_index].Subtract(this->smaller_leaf_histogram_array_[feature_index]);
@@ -309,15 +306,11 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
this->train_data_->FixHistogram(feature_index, this->larger_leaf_splits_->sum_gradients(), this->larger_leaf_splits_->sum_hessians(),
this->larger_leaf_histogram_array_[feature_index].RawData());
}
// find best threshold for larger child
this->larger_leaf_histogram_array_[feature_index].FindBestThreshold(
this->larger_leaf_splits_->sum_gradients(),
this->larger_leaf_splits_->sum_hessians(),
this->larger_leaf_splits_->num_data_in_leaf(),
this->larger_leaf_splits_->min_constraint(),
this->larger_leaf_splits_->max_constraint(),
&larger_bestsplit_per_features[feature_index]);
larger_bestsplit_per_features[feature_index].feature = real_feature_index;
this->ComputeBestSplitForFeature(
this->larger_leaf_histogram_array_, feature_index, real_feature_index,
true, this->larger_leaf_splits_->num_data_in_leaf(),
this->larger_leaf_splits_.get(),
&larger_bestsplit_per_features[feature_index]);
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
@@ -359,8 +352,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
}
// global voting
std::vector<int> smaller_top_features, larger_top_features;
GlobalVoting(this->smaller_leaf_splits_->LeafIndex(), smaller_top_k_splits_global, &smaller_top_features);
GlobalVoting(this->larger_leaf_splits_->LeafIndex(), larger_top_k_splits_global, &larger_top_features);
GlobalVoting(this->smaller_leaf_splits_->leaf_index(), smaller_top_k_splits_global, &smaller_top_features);
GlobalVoting(this->larger_leaf_splits_->leaf_index(), larger_top_k_splits_global, &larger_top_features);
// copy local histograms to buffer
CopyLocalHistogram(smaller_top_features, larger_top_features);
@@ -390,7 +383,6 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
const int tid = omp_get_thread_num();
const int real_feature_index = this->train_data_->RealFeatureIndex(feature_index);
if (smaller_is_feature_aggregated_[feature_index]) {
SplitInfo smaller_split;
// restore from buffer
smaller_leaf_histogram_array_global_[feature_index].FromMemory(
output_buffer_.data() + smaller_buffer_read_start_pos_[feature_index]);
@@ -399,22 +391,15 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
smaller_leaf_splits_global_->sum_gradients(), smaller_leaf_splits_global_->sum_hessians(),
smaller_leaf_histogram_array_global_[feature_index].RawData());
// find best threshold
smaller_leaf_histogram_array_global_[feature_index].FindBestThreshold(
smaller_leaf_splits_global_->sum_gradients(),
smaller_leaf_splits_global_->sum_hessians(),
GetGlobalDataCountInLeaf(smaller_leaf_splits_global_->LeafIndex()),
smaller_leaf_splits_global_->min_constraint(),
smaller_leaf_splits_global_->max_constraint(),
&smaller_split);
smaller_split.feature = real_feature_index;
if (smaller_split > smaller_bests_per_thread[tid] && smaller_node_used_features[feature_index]) {
smaller_bests_per_thread[tid] = smaller_split;
}
this->ComputeBestSplitForFeature(
smaller_leaf_histogram_array_global_.get(), feature_index,
real_feature_index, smaller_node_used_features[feature_index],
GetGlobalDataCountInLeaf(smaller_leaf_splits_global_->leaf_index()),
smaller_leaf_splits_global_.get(), &smaller_bests_per_thread[tid]);
}
if (larger_is_feature_aggregated_[feature_index]) {
SplitInfo larger_split;
// restore from buffer
larger_leaf_histogram_array_global_[feature_index].FromMemory(output_buffer_.data() + larger_buffer_read_start_pos_[feature_index]);
@@ -422,47 +407,42 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
larger_leaf_splits_global_->sum_gradients(), larger_leaf_splits_global_->sum_hessians(),
larger_leaf_histogram_array_global_[feature_index].RawData());
// find best threshold
larger_leaf_histogram_array_global_[feature_index].FindBestThreshold(
larger_leaf_splits_global_->sum_gradients(),
larger_leaf_splits_global_->sum_hessians(),
GetGlobalDataCountInLeaf(larger_leaf_splits_global_->LeafIndex()),
larger_leaf_splits_global_->min_constraint(),
larger_leaf_splits_global_->max_constraint(),
&larger_split);
larger_split.feature = real_feature_index;
if (larger_split > larger_best_per_thread[tid] && larger_node_used_features[feature_index]) {
larger_best_per_thread[tid] = larger_split;
}
this->ComputeBestSplitForFeature(
larger_leaf_histogram_array_global_.get(), feature_index,
real_feature_index,
larger_node_used_features[feature_index],
GetGlobalDataCountInLeaf(larger_leaf_splits_global_->leaf_index()),
larger_leaf_splits_global_.get(),
&larger_best_per_thread[tid]);
}
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_bests_per_thread);
int leaf = this->smaller_leaf_splits_->LeafIndex();
int leaf = this->smaller_leaf_splits_->leaf_index();
this->best_split_per_leaf_[leaf] = smaller_bests_per_thread[smaller_best_idx];
if (this->larger_leaf_splits_ != nullptr && this->larger_leaf_splits_->LeafIndex() >= 0) {
leaf = this->larger_leaf_splits_->LeafIndex();
if (this->larger_leaf_splits_ != nullptr && this->larger_leaf_splits_->leaf_index() >= 0) {
leaf = this->larger_leaf_splits_->leaf_index();
auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_best_per_thread);
this->best_split_per_leaf_[leaf] = larger_best_per_thread[larger_best_idx];
}
// find local best
SplitInfo smaller_best_split, larger_best_split;
smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()];
smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->leaf_index()];
// find local best split for larger leaf
if (this->larger_leaf_splits_->LeafIndex() >= 0) {
larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()];
if (this->larger_leaf_splits_->leaf_index() >= 0) {
larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->leaf_index()];
}
// sync global best info
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->config_->max_cat_threshold);
// copy back
this->best_split_per_leaf_[smaller_leaf_splits_global_->LeafIndex()] = smaller_best_split;
if (larger_best_split.feature >= 0 && larger_leaf_splits_global_->LeafIndex() >= 0) {
this->best_split_per_leaf_[larger_leaf_splits_global_->LeafIndex()] = larger_best_split;
this->best_split_per_leaf_[smaller_leaf_splits_global_->leaf_index()] = smaller_best_split;
if (larger_best_split.feature >= 0 && larger_leaf_splits_global_->leaf_index() >= 0) {
this->best_split_per_leaf_[larger_leaf_splits_global_->leaf_index()] = larger_best_split;
}
}
@@ -473,8 +453,6 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf,
// set the global number of data for leaves
global_data_count_in_leaf_[*left_leaf] = best_split_info.left_count;
global_data_count_in_leaf_[*right_leaf] = best_split_info.right_count;
auto p_left = smaller_leaf_splits_global_.get();
auto p_right = larger_leaf_splits_global_.get();
// init the global sumup info
if (best_split_info.left_count < best_split_info.right_count) {
smaller_leaf_splits_global_->Init(*left_leaf, this->data_partition_.get(),
@@ -490,22 +468,6 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf,
larger_leaf_splits_global_->Init(*left_leaf, this->data_partition_.get(),
best_split_info.left_sum_gradient,
best_split_info.left_sum_hessian);
p_left = larger_leaf_splits_global_.get();
p_right = smaller_leaf_splits_global_.get();
}
const int inner_feature_index = this->train_data_->InnerFeatureIndex(best_split_info.feature);
bool is_numerical_split = this->train_data_->FeatureBinMapper(inner_feature_index)->bin_type() == BinType::NumericalBin;
p_left->SetValueConstraint(best_split_info.min_constraint, best_split_info.max_constraint);
p_right->SetValueConstraint(best_split_info.min_constraint, best_split_info.max_constraint);
if (is_numerical_split) {
double mid = (best_split_info.left_output + best_split_info.right_output) / 2.0f;
if (best_split_info.monotone_type < 0) {
p_left->SetValueConstraint(mid, best_split_info.max_constraint);
p_right->SetValueConstraint(best_split_info.min_constraint, mid);
} else if (best_split_info.monotone_type > 0) {
p_left->SetValueConstraint(best_split_info.min_constraint, mid);
p_right->SetValueConstraint(mid, best_split_info.max_constraint);
}
}
}