Unverified Commit ebf962fc authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

max tree output ( max_delta_step) (#1322)

* first draft

* refine a branching
parent 381bc122
...@@ -245,6 +245,14 @@ Learning Control Parameters ...@@ -245,6 +245,14 @@ Learning Control Parameters
- L2 regularization - L2 regularization
- ``max_delta_step``, default=\ ``0``, type=double, alias=\ ``max_tree_output``, ``max_leaf_output``
- Used to limit the max output of tree leaves
- when <= 0, there is not constraint
- the final max output of leaves is ``learning_rate*max_delta_step``
- ``min_split_gain``, default=\ ``0``, type=double, alias=\ ``min_gain_to_split`` - ``min_split_gain``, default=\ ``0``, type=double, alias=\ ``min_gain_to_split``
- the minimal gain to perform split - the minimal gain to perform split
......
...@@ -203,6 +203,7 @@ struct TreeConfig: public ConfigBase { ...@@ -203,6 +203,7 @@ struct TreeConfig: public ConfigBase {
public: public:
int min_data_in_leaf = 20; int min_data_in_leaf = 20;
double min_sum_hessian_in_leaf = 1e-3f; double min_sum_hessian_in_leaf = 1e-3f;
double max_delta_step = 0.0f;
double lambda_l1 = 0.0f; double lambda_l1 = 0.0f;
double lambda_l2 = 0.0f; double lambda_l2 = 0.0f;
double min_gain_to_split = 0.0f; double min_gain_to_split = 0.0f;
...@@ -446,7 +447,9 @@ struct ParameterAlias { ...@@ -446,7 +447,9 @@ struct ParameterAlias {
{ "nodes", "machines" }, { "nodes", "machines" },
{ "subsample_for_bin", "bin_construct_sample_cnt" }, { "subsample_for_bin", "bin_construct_sample_cnt" },
{ "metric_freq", "output_freq" }, { "metric_freq", "output_freq" },
{ "mc", "monotone_constraints" } { "mc", "monotone_constraints" },
{ "max_tree_output", "max_delta_step" },
{ "max_leaf_output", "max_delta_step" }
}); });
const std::unordered_set<std::string> parameter_set({ const std::unordered_set<std::string> parameter_set({
"config", "config_file", "task", "device", "config", "config_file", "task", "device",
...@@ -479,7 +482,7 @@ struct ParameterAlias { ...@@ -479,7 +482,7 @@ struct ParameterAlias {
"histogram_pool_size", "is_provide_training_metric", "machine_list_filename", "machines", "histogram_pool_size", "is_provide_training_metric", "machine_list_filename", "machines",
"zero_as_missing", "init_score_file", "valid_init_score_file", "is_predict_contrib", "zero_as_missing", "init_score_file", "valid_init_score_file", "is_predict_contrib",
"max_cat_threshold", "cat_smooth", "min_data_per_group", "cat_l2", "max_cat_to_onehot", "max_cat_threshold", "cat_smooth", "min_data_per_group", "cat_l2", "max_cat_to_onehot",
"alpha", "reg_sqrt", "tweedie_variance_power", "monotone_constraints" "alpha", "reg_sqrt", "tweedie_variance_power", "monotone_constraints", "max_delta_step"
}); });
std::unordered_map<std::string, std::string> tmp_map; std::unordered_map<std::string, std::string> tmp_map;
for (const auto& pair : *params) { for (const auto& pair : *params) {
......
...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
namespace LightGBM { namespace LightGBM {
#define kMaxTreeOutput (100)
#define kCategoricalMask (1) #define kCategoricalMask (1)
#define kDefaultLeftMask (2) #define kDefaultLeftMask (2)
...@@ -141,7 +140,6 @@ public: ...@@ -141,7 +140,6 @@ public:
#pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048) #pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048)
for (int i = 0; i < num_leaves_; ++i) { for (int i = 0; i < num_leaves_; ++i) {
leaf_value_[i] *= rate; leaf_value_[i] *= rate;
if (leaf_value_[i] > kMaxTreeOutput) { leaf_value_[i] = kMaxTreeOutput; } else if (leaf_value_[i] < -kMaxTreeOutput) { leaf_value_[i] = -kMaxTreeOutput; }
} }
shrinkage_ *= rate; shrinkage_ *= rate;
} }
......
...@@ -405,6 +405,7 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params) ...@@ -405,6 +405,7 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
CHECK(lambda_l1 >= 0.0f); CHECK(lambda_l1 >= 0.0f);
GetDouble(params, "lambda_l2", &lambda_l2); GetDouble(params, "lambda_l2", &lambda_l2);
CHECK(lambda_l2 >= 0.0f); CHECK(lambda_l2 >= 0.0f);
GetDouble(params, "max_delta_step", &max_delta_step);
GetDouble(params, "min_gain_to_split", &min_gain_to_split); GetDouble(params, "min_gain_to_split", &min_gain_to_split);
CHECK(min_gain_to_split >= 0.0f); CHECK(min_gain_to_split >= 0.0f);
GetInt(params, "num_leaves", &num_leaves); GetInt(params, "num_leaves", &num_leaves);
......
...@@ -82,7 +82,7 @@ public: ...@@ -82,7 +82,7 @@ public:
is_splittable_ = false; is_splittable_ = false;
double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian, double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2); meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->tree_config->max_delta_step);
double min_gain_shift = gain_shift + meta_->tree_config->min_gain_to_split; double min_gain_shift = gain_shift + meta_->tree_config->min_gain_to_split;
if (meta_->num_bin > 2 && meta_->missing_type != MissingType::None) { if (meta_->num_bin > 2 && meta_->missing_type != MissingType::None) {
if (meta_->missing_type == MissingType::Zero) { if (meta_->missing_type == MissingType::Zero) {
...@@ -112,7 +112,7 @@ public: ...@@ -112,7 +112,7 @@ public:
data_size_t best_left_count = 0; data_size_t best_left_count = 0;
double best_sum_left_gradient = 0; double best_sum_left_gradient = 0;
double best_sum_left_hessian = 0; double best_sum_left_hessian = 0;
double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian, meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2); double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian, meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->tree_config->max_delta_step);
double min_gain_shift = gain_shift + meta_->tree_config->min_gain_to_split; double min_gain_shift = gain_shift + meta_->tree_config->min_gain_to_split;
bool is_full_categorical = meta_->missing_type == MissingType::None; bool is_full_categorical = meta_->missing_type == MissingType::None;
...@@ -140,7 +140,7 @@ public: ...@@ -140,7 +140,7 @@ public:
double sum_other_gradient = sum_gradient - data_[t].sum_gradients; double sum_other_gradient = sum_gradient - data_[t].sum_gradients;
// current split gain // current split gain
double current_gain = GetSplitGains(sum_other_gradient, sum_other_hessian, data_[t].sum_gradients, data_[t].sum_hessians + kEpsilon, double current_gain = GetSplitGains(sum_other_gradient, sum_other_hessian, data_[t].sum_gradients, data_[t].sum_hessians + kEpsilon,
meta_->tree_config->lambda_l1, l2, meta_->tree_config->lambda_l1, l2, meta_->tree_config->max_delta_step,
min_constraint, max_constraint, 0); min_constraint, max_constraint, 0);
// gain with split is worse than without split // gain with split is worse than without split
if (current_gain <= min_gain_shift) continue; if (current_gain <= min_gain_shift) continue;
...@@ -212,7 +212,7 @@ public: ...@@ -212,7 +212,7 @@ public:
double sum_right_gradient = sum_gradient - sum_left_gradient; double sum_right_gradient = sum_gradient - sum_left_gradient;
double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian, double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian,
meta_->tree_config->lambda_l1, l2, meta_->tree_config->lambda_l1, l2, meta_->tree_config->max_delta_step,
min_constraint, max_constraint, 0); min_constraint, max_constraint, 0);
if (current_gain <= min_gain_shift) continue; if (current_gain <= min_gain_shift) continue;
is_splittable_ = true; is_splittable_ = true;
...@@ -230,13 +230,15 @@ public: ...@@ -230,13 +230,15 @@ public:
if (is_splittable_) { if (is_splittable_) {
output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian, output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian,
meta_->tree_config->lambda_l1, l2, min_constraint, max_constraint); meta_->tree_config->lambda_l1, l2, meta_->tree_config->max_delta_step,
min_constraint, max_constraint);
output->left_count = best_left_count; output->left_count = best_left_count;
output->left_sum_gradient = best_sum_left_gradient; output->left_sum_gradient = best_sum_left_gradient;
output->left_sum_hessian = best_sum_left_hessian - kEpsilon; output->left_sum_hessian = best_sum_left_hessian - kEpsilon;
output->right_output = CalculateSplittedLeafOutput(sum_gradient - best_sum_left_gradient, output->right_output = CalculateSplittedLeafOutput(sum_gradient - best_sum_left_gradient,
sum_hessian - best_sum_left_hessian, sum_hessian - best_sum_left_hessian,
meta_->tree_config->lambda_l1, l2, min_constraint, max_constraint); meta_->tree_config->lambda_l1, l2, meta_->tree_config->max_delta_step,
min_constraint, max_constraint);
output->right_count = num_data - best_left_count; output->right_count = num_data - best_left_count;
output->right_sum_gradient = sum_gradient - best_sum_left_gradient; output->right_sum_gradient = sum_gradient - best_sum_left_gradient;
output->right_sum_hessian = sum_hessian - best_sum_left_hessian - kEpsilon; output->right_sum_hessian = sum_hessian - best_sum_left_hessian - kEpsilon;
...@@ -289,25 +291,28 @@ public: ...@@ -289,25 +291,28 @@ public:
*/ */
void set_is_splittable(bool val) { is_splittable_ = val; } void set_is_splittable(bool val) { is_splittable_ = val; }
/*! static double ThresholdL1(double s, double l1) {
* \brief Calculate the output of a leaf based on regularized sum_gradients and sum_hessians const double reg_s = std::max(0.0, std::fabs(s) - l1);
* \param sum_gradients return Common::Sign(s) * reg_s;
* \param sum_hessians }
* \return leaf output
*/ static double CalculateSplittedLeafOutput(double sum_gradients, double sum_hessians, double l1, double l2, double max_delta_step) {
static double CalculateSplittedLeafOutput(double sum_gradients, double sum_hessians, double l1, double l2) { double ret = -ThresholdL1(sum_gradients, l1) / (sum_hessians + l2);
const double reg_abs_sum_gradients = std::max(0.0, std::fabs(sum_gradients) - l1); if (max_delta_step <= 0.0f || std::fabs(ret) <= max_delta_step) {
return -(Common::Sign(sum_gradients) * reg_abs_sum_gradients) / (sum_hessians + l2); return ret;
} else {
return Common::Sign(ret) * max_delta_step;
}
} }
private: private:
static double GetSplitGains(double sum_left_gradients, double sum_left_hessians, static double GetSplitGains(double sum_left_gradients, double sum_left_hessians,
double sum_right_gradients, double sum_right_hessians, double sum_right_gradients, double sum_right_hessians,
double l1, double l2, double l1, double l2, double max_delta_step,
double min_constraint, double max_constraint, int8_t monotone_constraint) { double min_constraint, double max_constraint, int8_t monotone_constraint) {
double left_output = CalculateSplittedLeafOutput(sum_left_gradients, sum_left_hessians, l1, l2, min_constraint, max_constraint); double left_output = CalculateSplittedLeafOutput(sum_left_gradients, sum_left_hessians, l1, l2, max_delta_step, min_constraint, max_constraint);
double right_output = CalculateSplittedLeafOutput(sum_right_gradients, sum_right_hessians, l1, l2, min_constraint, max_constraint); double right_output = CalculateSplittedLeafOutput(sum_right_gradients, sum_right_hessians, l1, l2, max_delta_step, min_constraint, max_constraint);
if (((monotone_constraint > 0) && (left_output > right_output)) || if (((monotone_constraint > 0) && (left_output > right_output)) ||
((monotone_constraint < 0) && (left_output < right_output))) { ((monotone_constraint < 0) && (left_output < right_output))) {
return 0; return 0;
...@@ -316,20 +321,15 @@ private: ...@@ -316,20 +321,15 @@ private:
+ GetLeafSplitGainGivenOutput(sum_right_gradients, sum_right_hessians, l1, l2, right_output); + GetLeafSplitGainGivenOutput(sum_right_gradients, sum_right_hessians, l1, l2, right_output);
} }
static double ThresholdL1(double s, double l1) {
const double reg_s = std::max(0.0, std::fabs(s) - l1);
return Common::Sign(s) * reg_s;
}
/*! /*!
* \brief Calculate the output of a leaf based on regularized sum_gradients and sum_hessians * \brief Calculate the output of a leaf based on regularized sum_gradients and sum_hessians
* \param sum_gradients * \param sum_gradients
* \param sum_hessians * \param sum_hessians
* \return leaf output * \return leaf output
*/ */
static double CalculateSplittedLeafOutput(double sum_gradients, double sum_hessians, double l1, double l2, static double CalculateSplittedLeafOutput(double sum_gradients, double sum_hessians, double l1, double l2, double max_delta_step,
double min_constraint, double max_constraint) { double min_constraint, double max_constraint) {
double ret = -ThresholdL1(sum_gradients, l1) / (sum_hessians + l2); double ret = CalculateSplittedLeafOutput(sum_gradients, sum_hessians, l1, l2, max_delta_step);
if (ret < min_constraint) { if (ret < min_constraint) {
ret = min_constraint; ret = min_constraint;
} else if (ret > max_constraint) { } else if (ret > max_constraint) {
...@@ -344,11 +344,9 @@ private: ...@@ -344,11 +344,9 @@ private:
* \param sum_hessians * \param sum_hessians
* \return split gain * \return split gain
*/ */
static double GetLeafSplitGain(double sum_gradients, double sum_hessians, double l1, double l2) { static double GetLeafSplitGain(double sum_gradients, double sum_hessians, double l1, double l2, double max_delta_step) {
double abs_sum_gradients = std::fabs(sum_gradients); double output = CalculateSplittedLeafOutput(sum_gradients, sum_hessians, l1, l2, max_delta_step);
double reg_abs_sum_gradients = std::max(0.0, abs_sum_gradients - l1); return GetLeafSplitGainGivenOutput(sum_gradients, sum_hessians, l1, l2, output);
return (reg_abs_sum_gradients * reg_abs_sum_gradients)
/ (sum_hessians + l2);
} }
static double GetLeafSplitGainGivenOutput(double sum_gradients, double sum_hessians, double l1, double l2, double output) { static double GetLeafSplitGainGivenOutput(double sum_gradients, double sum_hessians, double l1, double l2, double output) {
...@@ -399,7 +397,7 @@ private: ...@@ -399,7 +397,7 @@ private:
double sum_left_gradient = sum_gradient - sum_right_gradient; double sum_left_gradient = sum_gradient - sum_right_gradient;
// current split gain // current split gain
double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian, double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->tree_config->max_delta_step,
min_constraint, max_constraint, meta_->monotone_type); min_constraint, max_constraint, meta_->monotone_type);
// gain with split is worse than without split // gain with split is worse than without split
if (current_gain <= min_gain_shift) continue; if (current_gain <= min_gain_shift) continue;
...@@ -459,7 +457,7 @@ private: ...@@ -459,7 +457,7 @@ private:
double sum_right_gradient = sum_gradient - sum_left_gradient; double sum_right_gradient = sum_gradient - sum_left_gradient;
// current split gain // current split gain
double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian, double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->tree_config->max_delta_step,
min_constraint, max_constraint, meta_->monotone_type); min_constraint, max_constraint, meta_->monotone_type);
// gain with split is worse than without split // gain with split is worse than without split
if (current_gain <= min_gain_shift) continue; if (current_gain <= min_gain_shift) continue;
...@@ -481,13 +479,15 @@ private: ...@@ -481,13 +479,15 @@ private:
// update split information // update split information
output->threshold = best_threshold; output->threshold = best_threshold;
output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian, output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, min_constraint, max_constraint); meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->tree_config->max_delta_step,
min_constraint, max_constraint);
output->left_count = best_left_count; output->left_count = best_left_count;
output->left_sum_gradient = best_sum_left_gradient; output->left_sum_gradient = best_sum_left_gradient;
output->left_sum_hessian = best_sum_left_hessian - kEpsilon; output->left_sum_hessian = best_sum_left_hessian - kEpsilon;
output->right_output = CalculateSplittedLeafOutput(sum_gradient - best_sum_left_gradient, output->right_output = CalculateSplittedLeafOutput(sum_gradient - best_sum_left_gradient,
sum_hessian - best_sum_left_hessian, sum_hessian - best_sum_left_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, min_constraint, max_constraint); meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->tree_config->max_delta_step,
min_constraint, max_constraint);
output->right_count = num_data - best_left_count; output->right_count = num_data - best_left_count;
output->right_sum_gradient = sum_gradient - best_sum_left_gradient; output->right_sum_gradient = sum_gradient - best_sum_left_gradient;
output->right_sum_hessian = sum_hessian - best_sum_left_hessian - kEpsilon; output->right_sum_hessian = sum_hessian - best_sum_left_hessian - kEpsilon;
......
...@@ -224,7 +224,7 @@ Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* ...@@ -224,7 +224,7 @@ Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t*
sum_hess += hessians[idx]; sum_hess += hessians[idx];
} }
double output = FeatureHistogram::CalculateSplittedLeafOutput(sum_grad, sum_hess, double output = FeatureHistogram::CalculateSplittedLeafOutput(sum_grad, sum_hess,
tree_config_->lambda_l1, tree_config_->lambda_l2); tree_config_->lambda_l1, tree_config_->lambda_l2, tree_config_->max_delta_step);
tree->SetLeafOutput(i, output* tree->shrinkage()); tree->SetLeafOutput(i, output* tree->shrinkage());
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment