Unverified commit 4278f222 authored by CharlesAuguste, committed by GitHub

Pr4 advanced method monotone constraints (#3264)



* No need to pass the tree to all functions related to monotone constraints because the pointer is shared.

* Fix OppositeChildShouldBeUpdated numerical split optimisation.

* No need to use constraints when computing the output of the root.

* Refactor existing constraints.

* Add advanced constraints method.

* Update tests.

* Add override.

* linting.

* Add override.

* Simplify condition in LeftRightContainsRelevantInformation.

* Add virtual destructor to FeatureConstraint.

* Remove redundant blank line.

* linting of else.

* Indentation.

* Lint else.

* Replaced non-const reference by pointers.

* Forgotten reference.

* Leverage USE_MC for efficiency.

* Make constraints const again in feature_histogram.hpp.

* Update docs.

* Add "advanced" to the monotone constraints options.

* Update monotone constraints restrictions.

* Fix loop iterator.
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Remove superfluous parenthesis.

* Remove std namespace qualifier.

* Fix unsigned_int size_t comparison.

* Set num_features as int for consistency with the rest of the codebase.

* Make sure constraints exist before recomputing them.

* Initialize previous constraints in UpdateConstraints.

* Update monotone constraints restrictions.

* Refactor UpdateConstraints loop.

* Update src/io/config.cpp
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Delete white spaces.
Co-authored-by: Charles Auguste <charles.auguste@sig.com>
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
parent 3454698e
@@ -462,7 +462,7 @@ Learning Control Parameters
 - you need to specify all features in order. For example, ``mc=-1,0,1`` means decreasing for 1st feature, non-constraint for 2nd feature and increasing for the 3rd feature
-- ``monotone_constraints_method`` :raw-html:`<a id="monotone_constraints_method" title="Permalink to this parameter" href="#monotone_constraints_method">&#x1F517;&#xFE0E;</a>`, default = ``basic``, type = enum, options: ``basic``, ``intermediate``, aliases: ``monotone_constraining_method``, ``mc_method``
+- ``monotone_constraints_method`` :raw-html:`<a id="monotone_constraints_method" title="Permalink to this parameter" href="#monotone_constraints_method">&#x1F517;&#xFE0E;</a>`, default = ``basic``, type = enum, options: ``basic``, ``intermediate``, ``advanced``, aliases: ``monotone_constraining_method``, ``mc_method``
 - used only if ``monotone_constraints`` is set
@@ -472,6 +472,8 @@ Learning Control Parameters
 - ``intermediate``, a `more advanced method <https://github.com/microsoft/LightGBM/files/3457826/PR-monotone-constraints-report.pdf>`__, which may slow the library very slightly. However, this method is much less constraining than the basic method and should significantly improve the results
+- ``advanced``, an `even more advanced method <https://github.com/microsoft/LightGBM/files/3457826/PR-monotone-constraints-report.pdf>`__, which may slow the library. However, this method is even less constraining than the intermediate method and should again significantly improve the results
 - ``monotone_penalty`` :raw-html:`<a id="monotone_penalty" title="Permalink to this parameter" href="#monotone_penalty">&#x1F517;&#xFE0E;</a>`, default = ``0.0``, type = double, aliases: ``monotone_splits_penalty``, ``ms_penalty``, ``mc_penalty``, constraints: ``monotone_penalty >= 0.0``
 - used only if ``monotone_constraints`` is set
...
@@ -443,11 +443,12 @@ struct Config {
   // type = enum
   // alias = monotone_constraining_method, mc_method
-  // options = basic, intermediate
+  // options = basic, intermediate, advanced
   // desc = used only if ``monotone_constraints`` is set
   // desc = monotone constraints method
   // descl2 = ``basic``, the most basic monotone constraints method. It does not slow the library at all, but over-constrains the predictions
   // descl2 = ``intermediate``, a `more advanced method <https://github.com/microsoft/LightGBM/files/3457826/PR-monotone-constraints-report.pdf>`__, which may slow the library very slightly. However, this method is much less constraining than the basic method and should significantly improve the results
+  // descl2 = ``advanced``, an `even more advanced method <https://github.com/microsoft/LightGBM/files/3457826/PR-monotone-constraints-report.pdf>`__, which may slow the library. However, this method is even less constraining than the intermediate method and should again significantly improve the results
   std::string monotone_constraints_method = "basic";
   // alias = monotone_splits_penalty, ms_penalty, mc_penalty
...
@@ -345,15 +345,15 @@ void Config::CheckParamConflict() {
     min_data_in_leaf = 2;
     Log::Warning("min_data_in_leaf has been increased to 2 because this is required when path smoothing is active.");
   }
-  if (is_parallel && monotone_constraints_method == std::string("intermediate")) {
+  if (is_parallel && (monotone_constraints_method == std::string("intermediate") || monotone_constraints_method == std::string("advanced"))) {
     // In distributed mode, local node doesn't have histograms on all features, cannot perform "intermediate" monotone constraints.
-    Log::Warning("Cannot use \"intermediate\" monotone constraints in parallel learning, auto set to \"basic\" method.");
+    Log::Warning("Cannot use \"intermediate\" or \"advanced\" monotone constraints in parallel learning, auto set to \"basic\" method.");
     monotone_constraints_method = "basic";
   }
-  if (feature_fraction_bynode != 1.0 && monotone_constraints_method == std::string("intermediate")) {
+  if (feature_fraction_bynode != 1.0 && (monotone_constraints_method == std::string("intermediate") || monotone_constraints_method == std::string("advanced"))) {
     // "intermediate" monotone constraints need to recompute splits. If the features are sampled when computing the
     // split initially, then the sampling needs to be recorded or done once again, which is currently not supported
-    Log::Warning("Cannot use \"intermediate\" monotone constraints with feature fraction different from 1, auto set monotone constraints to \"basic\" method.");
+    Log::Warning("Cannot use \"intermediate\" or \"advanced\" monotone constraints with feature fraction different from 1, auto set monotone constraints to \"basic\" method.");
     monotone_constraints_method = "basic";
   }
   if (max_depth > 0 && monotone_penalty >= max_depth) {
...
@@ -84,7 +84,7 @@ class FeatureHistogram {
   void FindBestThreshold(double sum_gradient, double sum_hessian,
                          data_size_t num_data,
-                         const ConstraintEntry& constraints,
+                         const FeatureConstraint* constraints,
                          double parent_output,
                          SplitInfo* output) {
     output->default_left = true;
@@ -158,7 +158,7 @@ class FeatureHistogram {
 #define TEMPLATE_PREFIX USE_RAND, USE_MC, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING
 #define LAMBDA_ARGUMENTS \
   double sum_gradient, double sum_hessian, data_size_t num_data, \
-  const ConstraintEntry &constraints, double parent_output, SplitInfo *output
+  const FeatureConstraint* constraints, double parent_output, SplitInfo *output
 #define BEFORE_ARGUMENTS sum_gradient, sum_hessian, parent_output, num_data, output, &rand_threshold
 #define FUNC_ARGUMENTS \
   sum_gradient, sum_hessian, num_data, constraints, min_gain_shift, \
@@ -278,7 +278,7 @@ class FeatureHistogram {
   void FindBestThresholdCategoricalInner(double sum_gradient,
                                          double sum_hessian,
                                          data_size_t num_data,
-                                         const ConstraintEntry& constraints,
+                                         const FeatureConstraint* constraints,
                                          double parent_output,
                                          SplitInfo* output) {
     is_splittable_ = false;
@@ -288,6 +288,9 @@ class FeatureHistogram {
     double best_sum_left_gradient = 0;
     double best_sum_left_hessian = 0;
     double gain_shift;
+    if (USE_MC) {
+      constraints->InitCumulativeConstraints(true);
+    }
     if (USE_SMOOTHING) {
       gain_shift = GetLeafGainGivenOutput<USE_L1>(
           sum_gradient, sum_hessian, meta_->config->lambda_l1, meta_->config->lambda_l2, parent_output);
@@ -474,14 +477,14 @@ class FeatureHistogram {
       output->left_output = CalculateSplittedLeafOutput<USE_MC, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
           best_sum_left_gradient, best_sum_left_hessian,
           meta_->config->lambda_l1, l2, meta_->config->max_delta_step,
-          constraints, meta_->config->path_smooth, best_left_count, parent_output);
+          constraints->LeftToBasicConstraint(), meta_->config->path_smooth, best_left_count, parent_output);
       output->left_count = best_left_count;
       output->left_sum_gradient = best_sum_left_gradient;
       output->left_sum_hessian = best_sum_left_hessian - kEpsilon;
       output->right_output = CalculateSplittedLeafOutput<USE_MC, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
           sum_gradient - best_sum_left_gradient,
           sum_hessian - best_sum_left_hessian, meta_->config->lambda_l1, l2,
-          meta_->config->max_delta_step, constraints, meta_->config->path_smooth,
+          meta_->config->max_delta_step, constraints->RightToBasicConstraint(), meta_->config->path_smooth,
           num_data - best_left_count, parent_output);
       output->right_count = num_data - best_left_count;
       output->right_sum_gradient = sum_gradient - best_sum_left_gradient;
@@ -763,7 +766,7 @@ class FeatureHistogram {
   template <bool USE_MC, bool USE_L1, bool USE_MAX_OUTPUT, bool USE_SMOOTHING>
   static double CalculateSplittedLeafOutput(
       double sum_gradients, double sum_hessians, double l1, double l2,
-      double max_delta_step, const ConstraintEntry& constraints,
+      double max_delta_step, const BasicConstraint& constraints,
       double smoothing, data_size_t num_data, double parent_output) {
     double ret = CalculateSplittedLeafOutput<USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
         sum_gradients, sum_hessians, l1, l2, max_delta_step, smoothing, num_data, parent_output);
@@ -784,7 +787,7 @@ class FeatureHistogram {
       double sum_right_gradients,
       double sum_right_hessians, double l1, double l2,
       double max_delta_step,
-      const ConstraintEntry& constraints,
+      const FeatureConstraint* constraints,
       int8_t monotone_constraint,
       double smoothing,
       data_size_t left_count,
@@ -803,11 +806,11 @@ class FeatureHistogram {
     double left_output =
         CalculateSplittedLeafOutput<USE_MC, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
             sum_left_gradients, sum_left_hessians, l1, l2, max_delta_step,
-            constraints, smoothing, left_count, parent_output);
+            constraints->LeftToBasicConstraint(), smoothing, left_count, parent_output);
     double right_output =
         CalculateSplittedLeafOutput<USE_MC, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
             sum_right_gradients, sum_right_hessians, l1, l2, max_delta_step,
-            constraints, smoothing, right_count, parent_output);
+            constraints->RightToBasicConstraint(), smoothing, right_count, parent_output);
     if (((monotone_constraint > 0) && (left_output > right_output)) ||
         ((monotone_constraint < 0) && (left_output < right_output))) {
       return 0;
@@ -854,7 +857,7 @@ class FeatureHistogram {
             bool REVERSE, bool SKIP_DEFAULT_BIN, bool NA_AS_MISSING>
   void FindBestThresholdSequentially(double sum_gradient, double sum_hessian,
                                      data_size_t num_data,
-                                     const ConstraintEntry& constraints,
+                                     const FeatureConstraint* constraints,
                                      double min_gain_shift, SplitInfo* output,
                                      int rand_threshold, double parent_output) {
     const int8_t offset = meta_->offset;
@@ -864,6 +867,16 @@ class FeatureHistogram {
     data_size_t best_left_count = 0;
     uint32_t best_threshold = static_cast<uint32_t>(meta_->num_bin);
     const double cnt_factor = num_data / sum_hessian;
+    BasicConstraint best_right_constraints;
+    BasicConstraint best_left_constraints;
+    bool constraint_update_necessary =
+        USE_MC && constraints->ConstraintDifferentDependingOnThreshold();
+    if (USE_MC) {
+      constraints->InitCumulativeConstraints(REVERSE);
+    }
     if (REVERSE) {
       double sum_right_gradient = 0.0f;
       double sum_right_hessian = kEpsilon;
@@ -910,6 +923,11 @@ class FeatureHistogram {
             continue;
           }
         }
+        if (USE_MC && constraint_update_necessary) {
+          constraints->Update(t + offset);
+        }
         // current split gain
         double current_gain = GetSplitGains<USE_MC, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
             sum_left_gradient, sum_left_hessian, sum_right_gradient,
@@ -932,6 +950,10 @@ class FeatureHistogram {
           // left is <= threshold, right is > threshold. so this is t-1
           best_threshold = static_cast<uint32_t>(t - 1 + offset);
           best_gain = current_gain;
+          if (USE_MC) {
+            best_right_constraints = constraints->RightToBasicConstraint();
+            best_left_constraints = constraints->LeftToBasicConstraint();
+          }
         }
       }
     } else {
@@ -1016,6 +1038,10 @@ class FeatureHistogram {
           best_sum_left_hessian = sum_left_hessian;
           best_threshold = static_cast<uint32_t>(t + offset);
           best_gain = current_gain;
+          if (USE_MC) {
+            best_right_constraints = constraints->RightToBasicConstraint();
+            best_left_constraints = constraints->LeftToBasicConstraint();
+          }
         }
       }
     }
@@ -1027,7 +1053,7 @@ class FeatureHistogram {
           CalculateSplittedLeafOutput<USE_MC, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
               best_sum_left_gradient, best_sum_left_hessian,
               meta_->config->lambda_l1, meta_->config->lambda_l2,
-              meta_->config->max_delta_step, constraints, meta_->config->path_smooth,
+              meta_->config->max_delta_step, best_left_constraints, meta_->config->path_smooth,
              best_left_count, parent_output);
       output->left_count = best_left_count;
       output->left_sum_gradient = best_sum_left_gradient;
@@ -1037,7 +1063,7 @@ class FeatureHistogram {
           sum_gradient - best_sum_left_gradient,
           sum_hessian - best_sum_left_hessian, meta_->config->lambda_l1,
           meta_->config->lambda_l2, meta_->config->max_delta_step,
-          constraints, meta_->config->path_smooth, num_data - best_left_count,
+          best_right_constraints, meta_->config->path_smooth, num_data - best_left_count,
           parent_output);
       output->right_count = num_data - best_left_count;
       output->right_sum_gradient = sum_gradient - best_sum_left_gradient;
@@ -1053,7 +1079,7 @@ class FeatureHistogram {
   hist_t* data_;
   bool is_splittable_ = true;
-  std::function<void(double, double, data_size_t, const ConstraintEntry&,
+  std::function<void(double, double, data_size_t, const FeatureConstraint*,
                      double, SplitInfo*)>
       find_best_threshold_fun_;
 };
...
@@ -16,22 +16,60 @@
 namespace LightGBM {
-struct ConstraintEntry {
+class LeafConstraintsBase;
+
+struct BasicConstraint {
   double min = -std::numeric_limits<double>::max();
   double max = std::numeric_limits<double>::max();
-  ConstraintEntry() {}
-  void Reset() {
+  BasicConstraint(double min, double max) : min(min), max(max) {}
+  BasicConstraint() = default;
+};
+
+struct FeatureConstraint {
+  virtual void InitCumulativeConstraints(bool) const {}
+  virtual void Update(int) const {}
+  virtual BasicConstraint LeftToBasicConstraint() const = 0;
+  virtual BasicConstraint RightToBasicConstraint() const = 0;
+  virtual bool ConstraintDifferentDependingOnThreshold() const = 0;
+  virtual ~FeatureConstraint() {}
+};
+
+struct ConstraintEntry {
+  virtual void Reset() = 0;
+  virtual void UpdateMin(double new_min) = 0;
+  virtual void UpdateMax(double new_max) = 0;
+  virtual bool UpdateMinAndReturnBoolIfChanged(double new_min) = 0;
+  virtual bool UpdateMaxAndReturnBoolIfChanged(double new_max) = 0;
+  virtual ConstraintEntry *clone() const = 0;
+  virtual void RecomputeConstraintsIfNeeded(LeafConstraintsBase *, int, int,
+                                            uint32_t) {}
+  virtual FeatureConstraint *GetFeatureConstraint(int feature_index) = 0;
+};
+
+// used by both BasicLeafConstraints and IntermediateLeafConstraints
+struct BasicConstraintEntry : ConstraintEntry,
+                              FeatureConstraint,
+                              BasicConstraint {
+  bool ConstraintDifferentDependingOnThreshold() const final { return false; }
+
+  BasicConstraintEntry *clone() const final {
+    return new BasicConstraintEntry(*this);
+  };
+
+  void Reset() final {
     min = -std::numeric_limits<double>::max();
     max = std::numeric_limits<double>::max();
   }
-  void UpdateMin(double new_min) { min = std::max(new_min, min); }
-  void UpdateMax(double new_max) { max = std::min(new_max, max); }
-  bool UpdateMinAndReturnBoolIfChanged(double new_min) {
+  void UpdateMin(double new_min) final { min = std::max(new_min, min); }
+  void UpdateMax(double new_max) final { max = std::min(new_max, max); }
+  bool UpdateMinAndReturnBoolIfChanged(double new_min) final {
     if (new_min > min) {
       min = new_min;
       return true;
@@ -39,29 +77,278 @@ struct ConstraintEntry {
     }
     return false;
   }
-  bool UpdateMaxAndReturnBoolIfChanged(double new_max) {
+  bool UpdateMaxAndReturnBoolIfChanged(double new_max) final {
     if (new_max < max) {
       max = new_max;
       return true;
     }
     return false;
   }
+  BasicConstraint LeftToBasicConstraint() const final { return *this; }
+  BasicConstraint RightToBasicConstraint() const final { return *this; }
+  FeatureConstraint *GetFeatureConstraint(int) final { return this; }
+};
+
+struct FeatureMinOrMaxConstraints {
+  std::vector<double> constraints;
+  // the constraint number i is valid on the slice
+  // [thresholds[i]:threshold[i+1])
+  // if threshold[i+1] does not exist, then it is valid for thresholds following
+  // threshold[i]
+  std::vector<uint32_t> thresholds;
+
+  FeatureMinOrMaxConstraints() {
+    constraints.reserve(32);
+    thresholds.reserve(32);
+  }
+
+  size_t Size() const { return thresholds.size(); }
+
+  explicit FeatureMinOrMaxConstraints(double extremum) {
+    constraints.reserve(32);
+    thresholds.reserve(32);
+    constraints.push_back(extremum);
+    thresholds.push_back(0);
+  }
+
+  void Reset(double extremum) {
+    constraints.resize(1);
+    constraints[0] = extremum;
+    thresholds.resize(1);
+    thresholds[0] = 0;
+  }
+
+  void UpdateMin(double min) {
+    for (size_t j = 0; j < constraints.size(); ++j) {
+      if (min > constraints[j]) {
+        constraints[j] = min;
+      }
+    }
+  }
+
+  void UpdateMax(double max) {
+    for (size_t j = 0; j < constraints.size(); ++j) {
+      if (max < constraints[j]) {
+        constraints[j] = max;
+      }
+    }
+  }
+};
+
+struct CumulativeFeatureConstraint {
+  std::vector<uint32_t> thresholds_min_constraints;
+  std::vector<uint32_t> thresholds_max_constraints;
+  std::vector<double> cumulative_min_constraints_left_to_right;
+  std::vector<double> cumulative_min_constraints_right_to_left;
+  std::vector<double> cumulative_max_constraints_left_to_right;
+  std::vector<double> cumulative_max_constraints_right_to_left;
+  size_t index_min_constraints_left_to_right;
+  size_t index_min_constraints_right_to_left;
+  size_t index_max_constraints_left_to_right;
+  size_t index_max_constraints_right_to_left;
+
+  static void CumulativeExtremum(
+      const double &(*extremum_function)(const double &, const double &),
+      bool is_direction_from_left_to_right,
+      std::vector<double>* cumulative_extremum) {
+    if (cumulative_extremum->size() == 1) {
+      return;
+    }
+#ifdef DEBUG
+    CHECK_NE(cumulative_extremum->size(), 0);
+#endif
+    size_t n_exts = cumulative_extremum->size();
+    int step = is_direction_from_left_to_right ? 1 : -1;
+    size_t start = is_direction_from_left_to_right ? 0 : n_exts - 1;
+    size_t end = is_direction_from_left_to_right ? n_exts - 1 : 0;
+    for (auto i = start; i != end; i = i + step) {
+      (*cumulative_extremum)[i + step] = extremum_function(
+          (*cumulative_extremum)[i + step], (*cumulative_extremum)[i]);
+    }
+  }
+
+  CumulativeFeatureConstraint() = default;
+
+  CumulativeFeatureConstraint(FeatureMinOrMaxConstraints min_constraints,
+                              FeatureMinOrMaxConstraints max_constraints,
+                              bool REVERSE) {
+    thresholds_min_constraints = min_constraints.thresholds;
+    thresholds_max_constraints = max_constraints.thresholds;
+    cumulative_min_constraints_left_to_right = min_constraints.constraints;
+    cumulative_min_constraints_right_to_left = min_constraints.constraints;
+    cumulative_max_constraints_left_to_right = max_constraints.constraints;
+    cumulative_max_constraints_right_to_left = max_constraints.constraints;
+
+    const double &(*min)(const double &, const double &) = std::min<double>;
+    const double &(*max)(const double &, const double &) = std::max<double>;
+    CumulativeExtremum(max, true, &cumulative_min_constraints_left_to_right);
+    CumulativeExtremum(max, false, &cumulative_min_constraints_right_to_left);
+    CumulativeExtremum(min, true, &cumulative_max_constraints_left_to_right);
+    CumulativeExtremum(min, false, &cumulative_max_constraints_right_to_left);
+
+    if (REVERSE) {
+      index_min_constraints_left_to_right =
+          thresholds_min_constraints.size() - 1;
+      index_min_constraints_right_to_left =
+          thresholds_min_constraints.size() - 1;
+      index_max_constraints_left_to_right =
+          thresholds_max_constraints.size() - 1;
+      index_max_constraints_right_to_left =
+          thresholds_max_constraints.size() - 1;
+    } else {
+      index_min_constraints_left_to_right = 0;
+      index_min_constraints_right_to_left = 0;
+      index_max_constraints_left_to_right = 0;
+      index_max_constraints_right_to_left = 0;
+    }
+  }
+
+  void Update(int threshold) {
+    while (
+        static_cast<int>(
+            thresholds_min_constraints[index_min_constraints_left_to_right]) >
+        threshold - 1) {
+      index_min_constraints_left_to_right -= 1;
+    }
+    while (
+        static_cast<int>(
+            thresholds_min_constraints[index_min_constraints_right_to_left]) >
+        threshold) {
+      index_min_constraints_right_to_left -= 1;
+    }
+    while (
+        static_cast<int>(
+            thresholds_max_constraints[index_max_constraints_left_to_right]) >
+        threshold - 1) {
+      index_max_constraints_left_to_right -= 1;
+    }
+    while (
+        static_cast<int>(
+            thresholds_max_constraints[index_max_constraints_right_to_left]) >
+        threshold) {
+      index_max_constraints_right_to_left -= 1;
+    }
+  }
+
+  double GetRightMin() const {
+    return cumulative_min_constraints_right_to_left
+        [index_min_constraints_right_to_left];
+  }
+  double GetRightMax() const {
+    return cumulative_max_constraints_right_to_left
+        [index_max_constraints_right_to_left];
+  }
+  double GetLeftMin() const {
+    return cumulative_min_constraints_left_to_right
+        [index_min_constraints_left_to_right];
+  }
+  double GetLeftMax() const {
+    return cumulative_max_constraints_left_to_right
+        [index_max_constraints_left_to_right];
+  }
+};
struct AdvancedFeatureConstraints : FeatureConstraint {
FeatureMinOrMaxConstraints min_constraints;
FeatureMinOrMaxConstraints max_constraints;
mutable CumulativeFeatureConstraint cumulative_feature_constraint;
bool min_constraints_to_be_recomputed = false;
bool max_constraints_to_be_recomputed = false;
void InitCumulativeConstraints(bool REVERSE) const final {
cumulative_feature_constraint =
CumulativeFeatureConstraint(min_constraints, max_constraints, REVERSE);
}
void Update(int threshold) const final {
cumulative_feature_constraint.Update(threshold);
}
FeatureMinOrMaxConstraints &GetMinConstraints() { return min_constraints; }
FeatureMinOrMaxConstraints &GetMaxConstraints() { return max_constraints; }
bool ConstraintDifferentDependingOnThreshold() const final {
return min_constraints.Size() > 1 || max_constraints.Size() > 1;
}
BasicConstraint RightToBasicConstraint() const final {
return BasicConstraint(cumulative_feature_constraint.GetRightMin(),
cumulative_feature_constraint.GetRightMax());
}
BasicConstraint LeftToBasicConstraint() const final {
return BasicConstraint(cumulative_feature_constraint.GetLeftMin(),
cumulative_feature_constraint.GetLeftMax());
}
void Reset() {
min_constraints.Reset(-std::numeric_limits<double>::max());
max_constraints.Reset(std::numeric_limits<double>::max());
}
void UpdateMax(double new_max, bool trigger_a_recompute) {
if (trigger_a_recompute) {
max_constraints_to_be_recomputed = true;
}
max_constraints.UpdateMax(new_max);
}
bool FeatureMaxConstraintsToBeUpdated() {
return max_constraints_to_be_recomputed;
}
bool FeatureMinConstraintsToBeUpdated() {
return min_constraints_to_be_recomputed;
}
void ResetUpdates() {
min_constraints_to_be_recomputed = false;
max_constraints_to_be_recomputed = false;
}
void UpdateMin(double new_min, bool trigger_a_recompute) {
if (trigger_a_recompute) {
min_constraints_to_be_recomputed = true;
}
min_constraints.UpdateMin(new_min);
}
};

class LeafConstraintsBase {
 public:
  virtual ~LeafConstraintsBase() {}
  virtual const ConstraintEntry* Get(int leaf_idx) = 0;
  virtual FeatureConstraint* GetFeatureConstraint(int leaf_idx, int feature_index) = 0;
  virtual void Reset() = 0;
  virtual void BeforeSplit(int leaf, int new_leaf,
                           int8_t monotone_type) = 0;
  virtual std::vector<int> Update(
      bool is_numerical_split,
      int leaf, int new_leaf, int8_t monotone_type, double right_output,
      double left_output, int split_feature, const SplitInfo& split_info,
      const std::vector<SplitInfo>& best_split_per_leaf) = 0;

  virtual void GoUpToFindConstrainingLeaves(
      int, int,
      std::vector<int>*,
      std::vector<uint32_t>*,
      std::vector<bool>*,
      FeatureMinOrMaxConstraints*, bool,
      uint32_t, uint32_t, uint32_t) {}

  virtual void RecomputeConstraintsIfNeeded(
      LeafConstraintsBase *constraints_,
      int feature_for_constraint, int leaf_idx, uint32_t it_end) = 0;

  inline static LeafConstraintsBase* Create(const Config* config, int num_leaves, int num_features);

  double ComputeMonotoneSplitGainPenalty(int leaf_index, double penalization) {
    int depth = tree_->leaf_depth(leaf_index);
@@ -78,48 +365,148 @@ class LeafConstraintsBase {
    tree_ = tree;
  }

 protected:
  const Tree* tree_;
};
// used by AdvancedLeafConstraints
struct AdvancedConstraintEntry : ConstraintEntry {
std::vector<AdvancedFeatureConstraints> constraints;
AdvancedConstraintEntry *clone() const final {
return new AdvancedConstraintEntry(*this);
};
void RecomputeConstraintsIfNeeded(LeafConstraintsBase *constraints_,
int feature_for_constraint, int leaf_idx,
uint32_t it_end) final {
if (constraints[feature_for_constraint]
.FeatureMinConstraintsToBeUpdated() ||
constraints[feature_for_constraint]
.FeatureMaxConstraintsToBeUpdated()) {
FeatureMinOrMaxConstraints &constraints_to_be_updated =
constraints[feature_for_constraint].FeatureMinConstraintsToBeUpdated()
? constraints[feature_for_constraint].GetMinConstraints()
: constraints[feature_for_constraint].GetMaxConstraints();
constraints_to_be_updated.Reset(
constraints[feature_for_constraint].FeatureMinConstraintsToBeUpdated()
? -std::numeric_limits<double>::max()
: std::numeric_limits<double>::max());
std::vector<int> features_of_splits_going_up_from_original_leaf =
std::vector<int>();
std::vector<uint32_t> thresholds_of_splits_going_up_from_original_leaf =
std::vector<uint32_t>();
std::vector<bool> was_original_leaf_right_child_of_split =
std::vector<bool>();
constraints_->GoUpToFindConstrainingLeaves(
feature_for_constraint, leaf_idx,
&features_of_splits_going_up_from_original_leaf,
&thresholds_of_splits_going_up_from_original_leaf,
&was_original_leaf_right_child_of_split, &constraints_to_be_updated,
constraints[feature_for_constraint]
.FeatureMinConstraintsToBeUpdated(),
0, it_end, it_end);
constraints[feature_for_constraint].ResetUpdates();
}
}
// for each feature, an array of constraints needs to be stored
explicit AdvancedConstraintEntry(int num_features) {
constraints.resize(num_features);
}
void Reset() final {
for (size_t i = 0; i < constraints.size(); ++i) {
constraints[i].Reset();
}
}
void UpdateMin(double new_min) final {
for (size_t i = 0; i < constraints.size(); ++i) {
constraints[i].UpdateMin(new_min, false);
}
}
void UpdateMax(double new_max) final {
for (size_t i = 0; i < constraints.size(); ++i) {
constraints[i].UpdateMax(new_max, false);
}
}
bool UpdateMinAndReturnBoolIfChanged(double new_min) final {
for (size_t i = 0; i < constraints.size(); ++i) {
constraints[i].UpdateMin(new_min, true);
}
// even if nothing changed, this could have been unconstrained so it needs
// to be recomputed from the beginning
return true;
}
bool UpdateMaxAndReturnBoolIfChanged(double new_max) final {
for (size_t i = 0; i < constraints.size(); ++i) {
constraints[i].UpdateMax(new_max, true);
}
// even if nothing changed, this could have been unconstrained so it needs
// to be recomputed from the beginning
return true;
}
FeatureConstraint *GetFeatureConstraint(int feature_index) final {
return &constraints[feature_index];
}
};
class BasicLeafConstraints : public LeafConstraintsBase {
 public:
  explicit BasicLeafConstraints(int num_leaves) : num_leaves_(num_leaves) {
    for (int i = 0; i < num_leaves; ++i) {
      entries_.push_back(new BasicConstraintEntry());
    }
  }

  void Reset() override {
    for (auto entry : entries_) {
      entry->Reset();
    }
  }

  void RecomputeConstraintsIfNeeded(
      LeafConstraintsBase* constraints_,
      int feature_for_constraint, int leaf_idx, uint32_t it_end) override {
    entries_[~leaf_idx]->RecomputeConstraintsIfNeeded(constraints_, feature_for_constraint, leaf_idx, it_end);
  }

  void BeforeSplit(int, int, int8_t) override {}

  std::vector<int> Update(bool is_numerical_split, int leaf, int new_leaf,
                          int8_t monotone_type, double right_output,
                          double left_output, int, const SplitInfo&,
                          const std::vector<SplitInfo>&) override {
    entries_[new_leaf] = entries_[leaf]->clone();
    if (is_numerical_split) {
      double mid = (left_output + right_output) / 2.0f;
      if (monotone_type < 0) {
        entries_[leaf]->UpdateMin(mid);
        entries_[new_leaf]->UpdateMax(mid);
      } else if (monotone_type > 0) {
        entries_[leaf]->UpdateMax(mid);
        entries_[new_leaf]->UpdateMin(mid);
      }
    }
    return std::vector<int>();
  }

  const ConstraintEntry* Get(int leaf_idx) override { return entries_[leaf_idx]; }
FeatureConstraint* GetFeatureConstraint(int leaf_idx, int feature_index) final {
return entries_[leaf_idx]->GetFeatureConstraint(feature_index);
}
 protected:
  int num_leaves_;
  std::vector<ConstraintEntry*> entries_;
};
class IntermediateLeafConstraints : public BasicLeafConstraints {
@@ -138,7 +525,7 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
    leaves_to_update_.clear();
  }

  void BeforeSplit(int leaf, int new_leaf,
                   int8_t monotone_type) override {
    if (monotone_type != 0 || leaf_is_in_monotone_subtree_[leaf]) {
      leaf_is_in_monotone_subtree_[leaf] = true;
@@ -148,36 +535,36 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
    CHECK_GE(new_leaf - 1, 0);
    CHECK_LT(static_cast<size_t>(new_leaf - 1), node_parent_.size());
#endif
    node_parent_[new_leaf - 1] = tree_->leaf_parent(leaf);
  }

  void UpdateConstraintsWithOutputs(bool is_numerical_split, int leaf,
                                    int new_leaf, int8_t monotone_type,
                                    double right_output, double left_output) {
    entries_[new_leaf] = entries_[leaf]->clone();
    if (is_numerical_split) {
      if (monotone_type < 0) {
        entries_[leaf]->UpdateMin(right_output);
        entries_[new_leaf]->UpdateMax(left_output);
      } else if (monotone_type > 0) {
        entries_[leaf]->UpdateMax(right_output);
        entries_[new_leaf]->UpdateMin(left_output);
      }
    }
  }

  std::vector<int> Update(bool is_numerical_split, int leaf,
                          int new_leaf, int8_t monotone_type,
                          double right_output, double left_output,
                          int split_feature, const SplitInfo& split_info,
                          const std::vector<SplitInfo>& best_split_per_leaf) final {
    leaves_to_update_.clear();
    if (leaf_is_in_monotone_subtree_[leaf]) {
      UpdateConstraintsWithOutputs(is_numerical_split, leaf, new_leaf,
                                   monotone_type, right_output, left_output);

      // Initialize variables to store information while going up the tree
      int depth = tree_->leaf_depth(new_leaf) - 1;

      std::vector<int> features_of_splits_going_up_from_original_leaf;
      std::vector<uint32_t> thresholds_of_splits_going_up_from_original_leaf;
@@ -187,7 +574,7 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
      thresholds_of_splits_going_up_from_original_leaf.reserve(depth);
      was_original_leaf_right_child_of_split.reserve(depth);

      GoUpToFindLeavesToUpdate(tree_->leaf_parent(new_leaf),
                               &features_of_splits_going_up_from_original_leaf,
                               &thresholds_of_splits_going_up_from_original_leaf,
                               &was_original_leaf_right_child_of_split,
@@ -203,8 +590,6 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
      int inner_feature,
      const std::vector<bool>& was_original_leaf_right_child_of_split,
      bool is_in_right_child) {
    // if the split is categorical, it is not handled by this optimisation,
    // so the code will have to go down in the other child subtree to see if
    // there are leaves to update
@@ -221,18 +606,19 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
            inner_feature &&
            (was_original_leaf_right_child_of_split[split_idx] ==
             is_in_right_child)) {
          return false;
        }
      }
      return true;
    } else {
      return false;
    }
  }

  // Recursive function that goes up the tree, and then down to find leaves that
  // have constraints to be updated
  void GoUpToFindLeavesToUpdate(
      int node_idx,
      std::vector<int>* features_of_splits_going_up_from_original_leaf,
      std::vector<uint32_t>* thresholds_of_splits_going_up_from_original_leaf,
      std::vector<bool>* was_original_leaf_right_child_of_split,
@@ -245,11 +631,11 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
    int parent_idx = node_parent_[node_idx];
    // if not at the root
    if (parent_idx != -1) {
      int inner_feature = tree_->split_feature_inner(parent_idx);
      int feature = tree_->split_feature(parent_idx);
      int8_t monotone_type = config_->monotone_constraints[feature];
      bool is_in_right_child = tree_->right_child(parent_idx) == node_idx;
      bool is_split_numerical = tree_->IsNumericalSplit(parent_idx);

      // this is just an optimisation not to waste time going down in subtrees
      // where there won't be any leaf to update
@@ -264,8 +650,8 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
      if (monotone_type != 0) {
        // these variables correspond to the current split we encounter going
        // up the tree
        int left_child_idx = tree_->left_child(parent_idx);
        int right_child_idx = tree_->right_child(parent_idx);
        bool left_child_is_curr_idx = (left_child_idx == node_idx);
        int opposite_child_idx =
            (left_child_is_curr_idx) ? right_child_idx : left_child_idx;
@@ -277,7 +663,7 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
          // so the code needs to go down in the opposite child
          // to see which leaves' constraints need to be updated
          GoDownToFindLeavesToUpdate(
              opposite_child_idx,
              *features_of_splits_going_up_from_original_leaf,
              *thresholds_of_splits_going_up_from_original_leaf,
              *was_original_leaf_right_child_of_split,
@@ -290,16 +676,16 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
      // is actually contiguous to the original 2 leaves and should be updated
      // so the variables associated with the split need to be recorded
      was_original_leaf_right_child_of_split->push_back(
          tree_->right_child(parent_idx) == node_idx);
      thresholds_of_splits_going_up_from_original_leaf->push_back(
          tree_->threshold_in_bin(parent_idx));
      features_of_splits_going_up_from_original_leaf->push_back(
          tree_->split_feature_inner(parent_idx));

      // since current node is not the root, keep going up
      GoUpToFindLeavesToUpdate(
          parent_idx, features_of_splits_going_up_from_original_leaf,
          thresholds_of_splits_going_up_from_original_leaf,
          was_original_leaf_right_child_of_split, split_feature, split_info,
          split_threshold, best_split_per_leaf);
@@ -307,7 +693,7 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
  }

  void GoDownToFindLeavesToUpdate(
      int node_idx,
      const std::vector<int>& features_of_splits_going_up_from_original_leaf,
      const std::vector<uint32_t>&
          thresholds_of_splits_going_up_from_original_leaf,
@@ -345,18 +731,18 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
#ifdef DEBUG
      if (update_max_constraints) {
        CHECK_GE(min_max_constraints.first, tree_->LeafOutput(leaf_idx));
      } else {
        CHECK_LE(min_max_constraints.second, tree_->LeafOutput(leaf_idx));
      }
#endif
      // depending on which split made the current leaf and the original leaves contiguous,
      // either the min constraint or the max constraint of the current leaf needs to be updated
      if (!update_max_constraints) {
        something_changed = entries_[leaf_idx]->UpdateMinAndReturnBoolIfChanged(
            min_max_constraints.second);
      } else {
        something_changed = entries_[leaf_idx]->UpdateMaxAndReturnBoolIfChanged(
            min_max_constraints.first);
      }
      // If constraints were not updated, then there is no need to update the leaf
@@ -368,12 +754,12 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
    } else {  // if node
      // check if the children are contiguous with the original leaf
      std::pair<bool, bool> keep_going_left_right = ShouldKeepGoingLeftRight(
          node_idx, features_of_splits_going_up_from_original_leaf,
          thresholds_of_splits_going_up_from_original_leaf,
          was_original_leaf_right_child_of_split);
      int inner_feature = tree_->split_feature_inner(node_idx);
      uint32_t threshold = tree_->threshold_in_bin(node_idx);
      bool is_split_numerical = tree_->IsNumericalSplit(node_idx);
      bool use_left_leaf_for_update_right = true;
      bool use_right_leaf_for_update_left = true;
      // if the split is on the same feature (categorical variables not supported)
@@ -392,7 +778,7 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
      // go down left
      if (keep_going_left_right.first) {
        GoDownToFindLeavesToUpdate(
            tree_->left_child(node_idx),
            features_of_splits_going_up_from_original_leaf,
            thresholds_of_splits_going_up_from_original_leaf,
            was_original_leaf_right_child_of_split, update_max_constraints,
@@ -403,7 +789,7 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
      // go down right
      if (keep_going_left_right.second) {
        GoDownToFindLeavesToUpdate(
            tree_->right_child(node_idx),
            features_of_splits_going_up_from_original_leaf,
            thresholds_of_splits_going_up_from_original_leaf,
            was_original_leaf_right_child_of_split, update_max_constraints,
@@ -415,14 +801,14 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
  }

  std::pair<bool, bool> ShouldKeepGoingLeftRight(
      int node_idx,
      const std::vector<int>& features_of_splits_going_up_from_original_leaf,
      const std::vector<uint32_t>&
          thresholds_of_splits_going_up_from_original_leaf,
      const std::vector<bool>& was_original_leaf_right_child_of_split) {
    int inner_feature = tree_->split_feature_inner(node_idx);
    uint32_t threshold = tree_->threshold_in_bin(node_idx);
    bool is_split_numerical = tree_->IsNumericalSplit(node_idx);

    bool keep_going_right = true;
    bool keep_going_left = true;
@@ -456,7 +842,7 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
    return std::pair<bool, bool>(keep_going_left, keep_going_right);
  }

 protected:
  const Config* config_;
  std::vector<int> leaves_to_update_;
  // add parent node information
@@ -465,11 +851,330 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
  std::vector<bool> leaf_is_in_monotone_subtree_;
};
class AdvancedLeafConstraints : public IntermediateLeafConstraints {
public:
AdvancedLeafConstraints(const Config *config, int num_leaves,
int num_features)
: IntermediateLeafConstraints(config, num_leaves) {
for (int i = 0; i < num_leaves; ++i) {
entries_[i] = new AdvancedConstraintEntry(num_features);
}
}
  // at any point in time, for an index i, the constraint constraints[i] has
  // to be valid on [thresholds[i]: thresholds[i + 1]) (or [thresholds[i]: +inf)
  // if i is the last index of the array)
void UpdateConstraints(FeatureMinOrMaxConstraints* feature_constraint,
double extremum, uint32_t it_start, uint32_t it_end,
bool use_max_operator, uint32_t last_threshold) {
bool start_done = false;
bool end_done = false;
    // the previous constraint has to be tracked
    // for example when adding a constraint cstr2 on thresholds [1:2),
    // on an existing constraint cstr1 on thresholds [0, +inf),
// the thresholds and constraints must become
// [0, 1, 2] and [cstr1, cstr2, cstr1]
// so since we loop through thresholds only once,
// the previous constraint that still applies needs to be recorded
double previous_constraint = use_max_operator
? -std::numeric_limits<double>::max()
: std::numeric_limits<double>::max();
double current_constraint;
for (size_t i = 0; i < feature_constraint->thresholds.size(); ++i) {
current_constraint = feature_constraint->constraints[i];
// easy case when the thresholds match
if (feature_constraint->thresholds[i] == it_start) {
feature_constraint->constraints[i] =
(use_max_operator)
? std::max(extremum, feature_constraint->constraints[i])
: std::min(extremum, feature_constraint->constraints[i]);
start_done = true;
}
if (feature_constraint->thresholds[i] > it_start) {
// existing constraint is updated if there is a need for it
if (feature_constraint->thresholds[i] < it_end) {
feature_constraint->constraints[i] =
(use_max_operator)
? std::max(extremum, feature_constraint->constraints[i])
: std::min(extremum, feature_constraint->constraints[i]);
}
// when thresholds don't match, a new threshold
// and a new constraint may need to be inserted
if (!start_done) {
start_done = true;
if ((use_max_operator && extremum > previous_constraint) ||
(!use_max_operator && extremum < previous_constraint)) {
feature_constraint->constraints.insert(
feature_constraint->constraints.begin() + i, extremum);
feature_constraint->thresholds.insert(
feature_constraint->thresholds.begin() + i, it_start);
++i;
}
}
}
// easy case when the end thresholds match
if (feature_constraint->thresholds[i] == it_end) {
end_done = true;
break;
}
// if they don't then, the previous constraint needs to be added back
// where the current one ends
if (feature_constraint->thresholds[i] > it_end) {
if (i != 0 &&
previous_constraint != feature_constraint->constraints[i - 1]) {
feature_constraint->constraints.insert(
feature_constraint->constraints.begin() + i, previous_constraint);
feature_constraint->thresholds.insert(
feature_constraint->thresholds.begin() + i, it_end);
}
end_done = true;
break;
}
// If 2 successive constraints are the same then the second one may as
// well be deleted
if (i != 0 && feature_constraint->constraints[i] ==
feature_constraint->constraints[i - 1]) {
feature_constraint->constraints.erase(
feature_constraint->constraints.begin() + i);
feature_constraint->thresholds.erase(
feature_constraint->thresholds.begin() + i);
previous_constraint = current_constraint;
--i;
}
previous_constraint = current_constraint;
}
// if the loop didn't get to an index greater than it_start, it needs to be
// added at the end
if (!start_done) {
if ((use_max_operator &&
extremum > feature_constraint->constraints.back()) ||
(!use_max_operator &&
extremum < feature_constraint->constraints.back())) {
feature_constraint->constraints.push_back(extremum);
feature_constraint->thresholds.push_back(it_start);
} else {
end_done = true;
}
}
// if we didn't get to an index after it_end, then the previous constraint
// needs to be set back, unless it_end goes up to the last bin of the feature
if (!end_done && it_end != last_threshold &&
previous_constraint != feature_constraint->constraints.back()) {
feature_constraint->constraints.push_back(previous_constraint);
feature_constraint->thresholds.push_back(it_end);
}
}
// this function is called only when computing constraints when the monotone
// precise mode is set to true
// it makes sure that it is worth it to visit a branch, as it could
  // not contain any relevant constraint (for example, if a branch
// with bigger values is also constraining the original leaf, then
// it is useless to visit the branch with smaller values)
std::pair<bool, bool>
LeftRightContainsRelevantInformation(bool min_constraints_to_be_updated,
int feature,
bool split_feature_is_inner_feature) {
if (split_feature_is_inner_feature) {
return std::pair<bool, bool>(true, true);
}
int8_t monotone_type = config_->monotone_constraints[feature];
if (monotone_type == 0) {
return std::pair<bool, bool>(true, true);
}
if ((monotone_type == -1 && min_constraints_to_be_updated) ||
(monotone_type == 1 && !min_constraints_to_be_updated)) {
return std::pair<bool, bool>(true, false);
} else {
// Same as
// if ((monotone_type == 1 && min_constraints_to_be_updated) ||
// (monotone_type == -1 && !min_constraints_to_be_updated))
return std::pair<bool, bool>(false, true);
}
}
// this function goes down in a subtree to find the
// constraints that would apply on the original leaf
void GoDownToFindConstrainingLeaves(
int feature_for_constraint, int root_monotone_feature, int node_idx,
bool min_constraints_to_be_updated, uint32_t it_start, uint32_t it_end,
const std::vector<int> &features_of_splits_going_up_from_original_leaf,
const std::vector<uint32_t> &
thresholds_of_splits_going_up_from_original_leaf,
const std::vector<bool> &was_original_leaf_right_child_of_split,
FeatureMinOrMaxConstraints* feature_constraint, uint32_t last_threshold) {
double extremum;
// if leaf, then constraints need to be updated according to its value
if (node_idx < 0) {
extremum = tree_->LeafOutput(~node_idx);
#ifdef DEBUG
CHECK(it_start < it_end);
#endif
UpdateConstraints(feature_constraint, extremum, it_start, it_end,
min_constraints_to_be_updated, last_threshold);
} else { // if node, keep going down the tree
// check if the children are contiguous to the original leaf and therefore
// potentially constraining
std::pair<bool, bool> keep_going_left_right = ShouldKeepGoingLeftRight(
node_idx, features_of_splits_going_up_from_original_leaf,
thresholds_of_splits_going_up_from_original_leaf,
was_original_leaf_right_child_of_split);
int inner_feature = tree_->split_feature_inner(node_idx);
int feature = tree_->split_feature(node_idx);
uint32_t threshold = tree_->threshold_in_bin(node_idx);
bool split_feature_is_inner_feature =
(inner_feature == feature_for_constraint);
bool split_feature_is_monotone_feature =
(root_monotone_feature == feature_for_constraint);
// make sure that both children contain values that could
// potentially help determine the true constraints for the original leaf
std::pair<bool, bool> left_right_contain_relevant_information =
LeftRightContainsRelevantInformation(
min_constraints_to_be_updated, feature,
split_feature_is_inner_feature &&
!split_feature_is_monotone_feature);
// if both children are contiguous to the original leaf
// but one contains values greater than the other
// then no need to go down in both
if (keep_going_left_right.first &&
(left_right_contain_relevant_information.first ||
!keep_going_left_right.second)) {
// update thresholds based on going left
uint32_t new_it_end = split_feature_is_inner_feature
? std::min(threshold + 1, it_end)
: it_end;
GoDownToFindConstrainingLeaves(
feature_for_constraint, root_monotone_feature,
tree_->left_child(node_idx), min_constraints_to_be_updated,
it_start, new_it_end,
features_of_splits_going_up_from_original_leaf,
thresholds_of_splits_going_up_from_original_leaf,
was_original_leaf_right_child_of_split, feature_constraint,
last_threshold);
}
if (keep_going_left_right.second &&
(left_right_contain_relevant_information.second ||
!keep_going_left_right.first)) {
// update thresholds based on going right
uint32_t new_it_start = split_feature_is_inner_feature
? std::max(threshold + 1, it_start)
: it_start;
GoDownToFindConstrainingLeaves(
feature_for_constraint, root_monotone_feature,
tree_->right_child(node_idx), min_constraints_to_be_updated,
new_it_start, it_end,
features_of_splits_going_up_from_original_leaf,
thresholds_of_splits_going_up_from_original_leaf,
was_original_leaf_right_child_of_split, feature_constraint,
last_threshold);
}
}
}
// this function is only used if the monotone precise mode is enabled
// it recursively goes up the tree then down to find leaves that
// are constraining the current leaf
void GoUpToFindConstrainingLeaves(
int feature_for_constraint, int node_idx,
std::vector<int>* features_of_splits_going_up_from_original_leaf,
std::vector<uint32_t>* thresholds_of_splits_going_up_from_original_leaf,
std::vector<bool>* was_original_leaf_right_child_of_split,
FeatureMinOrMaxConstraints* feature_constraint,
bool min_constraints_to_be_updated, uint32_t it_start, uint32_t it_end,
uint32_t last_threshold) final {
int parent_idx =
(node_idx < 0) ? tree_->leaf_parent(~node_idx) : node_parent_[node_idx];
// if not at the root
if (parent_idx != -1) {
int inner_feature = tree_->split_feature_inner(parent_idx);
int feature = tree_->split_feature(parent_idx);
int8_t monotone_type = config_->monotone_constraints[feature];
bool is_in_right_child = tree_->right_child(parent_idx) == node_idx;
bool is_split_numerical = tree_->IsNumericalSplit(parent_idx);
uint32_t threshold = tree_->threshold_in_bin(parent_idx);
// by going up, more information about the position of the
// original leaf is gathered, so the starting and ending
// thresholds can be updated, which will save some time later
if ((feature_for_constraint == inner_feature) && is_split_numerical) {
if (is_in_right_child) {
it_start = std::max(threshold, it_start);
} else {
it_end = std::min(threshold + 1, it_end);
}
#ifdef DEBUG
CHECK(it_start < it_end);
#endif
}
// this is just an optimisation to avoid wasting time going down subtrees
// where there won't be any new constraining leaf
bool opposite_child_necessary_to_update_constraints =
OppositeChildShouldBeUpdated(
is_split_numerical,
*features_of_splits_going_up_from_original_leaf, inner_feature,
*was_original_leaf_right_child_of_split, is_in_right_child);
if (opposite_child_necessary_to_update_constraints) {
// if there is no monotone constraint on a split,
// then there is no relationship between its left and right leaves'
// values
if (monotone_type != 0) {
int left_child_idx = tree_->left_child(parent_idx);
int right_child_idx = tree_->right_child(parent_idx);
bool left_child_is_curr_idx = (left_child_idx == node_idx);
bool update_min_constraints_in_curr_child_leaf =
(monotone_type < 0) ? left_child_is_curr_idx
: !left_child_is_curr_idx;
if (update_min_constraints_in_curr_child_leaf ==
min_constraints_to_be_updated) {
int opposite_child_idx =
(left_child_is_curr_idx) ? right_child_idx : left_child_idx;
// go down in the opposite branch to find potential
// constraining leaves
GoDownToFindConstrainingLeaves(
feature_for_constraint, inner_feature, opposite_child_idx,
min_constraints_to_be_updated, it_start, it_end,
*features_of_splits_going_up_from_original_leaf,
*thresholds_of_splits_going_up_from_original_leaf,
*was_original_leaf_right_child_of_split, feature_constraint,
last_threshold);
}
}
// if the opposite child had to be checked, then the path taken to come
// up here was relevant,
// i.e. it will be helpful when going down to determine which leaf
// is actually contiguous to the original leaf and constraining,
// so the variables associated with the split need to be recorded
was_original_leaf_right_child_of_split->push_back(is_in_right_child);
thresholds_of_splits_going_up_from_original_leaf->push_back(threshold);
features_of_splits_going_up_from_original_leaf->push_back(inner_feature);
}
// if the parent is not the root, keep going up
if (parent_idx != 0) {
GoUpToFindConstrainingLeaves(
feature_for_constraint, parent_idx,
features_of_splits_going_up_from_original_leaf,
thresholds_of_splits_going_up_from_original_leaf,
was_original_leaf_right_child_of_split, feature_constraint,
min_constraints_to_be_updated, it_start, it_end, last_threshold);
}
}
}
};
 LeafConstraintsBase* LeafConstraintsBase::Create(const Config* config,
-                                                 int num_leaves) {
+                                                 int num_leaves, int num_features) {
   if (config->monotone_constraints_method == "intermediate") {
     return new IntermediateLeafConstraints(config, num_leaves);
   }
+  if (config->monotone_constraints_method == "advanced") {
+    return new AdvancedLeafConstraints(config, num_leaves, num_features);
+  }
   return new BasicLeafConstraints(num_leaves);
 }
@@ -46,7 +46,7 @@ void SerialTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian
   // push split information for all leaves
   best_split_per_leaf_.resize(config_->num_leaves);
-  constraints_.reset(LeafConstraintsBase::Create(config_, config_->num_leaves));
+  constraints_.reset(LeafConstraintsBase::Create(config_, config_->num_leaves, train_data_->num_features()));
   // initialize splits for leaf
   smaller_leaf_splits_.reset(new LeafSplits(train_data_->num_data()));
@@ -146,7 +146,7 @@ void SerialTreeLearner::ResetConfig(const Config* config) {
     }
     cegb_->Init();
   }
-  constraints_.reset(LeafConstraintsBase::Create(config_, config_->num_leaves));
+  constraints_.reset(LeafConstraintsBase::Create(config_, config_->num_leaves, train_data_->num_features()));
 }
 Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians) {
@@ -561,7 +561,7 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
   auto next_leaf_id = tree->NextLeafId();
   // update before tree split
-  constraints_->BeforeSplit(tree, best_leaf, next_leaf_id,
+  constraints_->BeforeSplit(best_leaf, next_leaf_id,
                             best_split_info.monotone_type);
   bool is_numerical_split =
@@ -657,7 +657,7 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
                        best_split_info.left_output);
   }
   auto leaves_need_update = constraints_->Update(
-      tree, is_numerical_split, *left_leaf, *right_leaf,
+      is_numerical_split, *left_leaf, *right_leaf,
       best_split_info.monotone_type, best_split_info.right_output,
       best_split_info.left_output, inner_feature_index, best_split_info,
       best_split_per_leaf_);
@@ -711,20 +711,29 @@ void SerialTreeLearner::ComputeBestSplitForFeature(
     FeatureHistogram* histogram_array_, int feature_index, int real_fidx,
     bool is_feature_used, int num_data, const LeafSplits* leaf_splits,
     SplitInfo* best_split) {
+  bool is_feature_numerical = train_data_->FeatureBinMapper(feature_index)
+                                  ->bin_type() == BinType::NumericalBin;
+  if (is_feature_numerical & !config_->monotone_constraints.empty()) {
+    constraints_->RecomputeConstraintsIfNeeded(
+        constraints_.get(), feature_index, ~(leaf_splits->leaf_index()),
+        train_data_->FeatureNumBin(feature_index));
+  }
   SplitInfo new_split;
   double parent_output;
   if (leaf_splits->leaf_index() == 0) {
     // for root leaf the "parent" output is its own output because we don't apply any smoothing to the root
-    parent_output = FeatureHistogram::CalculateSplittedLeafOutput<true, true, true, false>(
+    parent_output = FeatureHistogram::CalculateSplittedLeafOutput<false, true, true, false>(
         leaf_splits->sum_gradients(), leaf_splits->sum_hessians(), config_->lambda_l1,
-        config_->lambda_l2, config_->max_delta_step, constraints_->Get(leaf_splits->leaf_index()),
+        config_->lambda_l2, config_->max_delta_step, BasicConstraint(),
         config_->path_smooth, static_cast<data_size_t>(num_data), 0);
   } else {
     parent_output = leaf_splits->weight();
   }
   histogram_array_[feature_index].FindBestThreshold(
       leaf_splits->sum_gradients(), leaf_splits->sum_hessians(), num_data,
-      constraints_->Get(leaf_splits->leaf_index()), parent_output, &new_split);
+      constraints_->GetFeatureConstraint(leaf_splits->leaf_index(), feature_index), parent_output, &new_split);
   new_split.feature = real_fidx;
   if (cegb_ != nullptr) {
     new_split.gain -=
@@ -1247,7 +1247,7 @@ class TestEngine(unittest.TestCase):
         for test_with_categorical_variable in [True, False]:
             trainset = self.generate_trainset_for_monotone_constraints_tests(test_with_categorical_variable)
-            for monotone_constraints_method in ["basic", "intermediate"]:
+            for monotone_constraints_method in ["basic", "intermediate", "advanced"]:
                 params = {
                     'min_data': 20,
                     'num_leaves': 20,
@@ -1281,7 +1281,7 @@ class TestEngine(unittest.TestCase):
         monotone_constraints = [1, -1, 0]
         penalization_parameter = 2.0
         trainset = self.generate_trainset_for_monotone_constraints_tests(x3_to_category=False)
-        for monotone_constraints_method in ["basic", "intermediate"]:
+        for monotone_constraints_method in ["basic", "intermediate", "advanced"]:
             params = {
                 'max_depth': max_depth,
                 'monotone_constraints': monotone_constraints,
@@ -1320,7 +1320,7 @@ class TestEngine(unittest.TestCase):
         unconstrained_model_predictions = unconstrained_model.\
             predict(x3_negatively_correlated_with_y.reshape(-1, 1))
-        for monotone_constraints_method in ["basic", "intermediate"]:
+        for monotone_constraints_method in ["basic", "intermediate", "advanced"]:
             params_constrained_model["monotone_constraints_method"] = monotone_constraints_method
             # The penalization is so high that the first 2 features should not be used here
             constrained_model = lgb.train(params_constrained_model, trainset_constrained_model, 10)