Unverified Commit 446b8b6c authored by Belinda Trotta's avatar Belinda Trotta Committed by GitHub
Browse files

Extremely randomized trees (#2671)

* Add extra-trees functionality.

* Remove unnecessary code.

* Update docs.

* Use template for FindBestThresholdSequence.

* Use separate random seed. Fix bug.
parent da811d46
...@@ -77,3 +77,5 @@ Deal with Over-fitting ...@@ -77,3 +77,5 @@ Deal with Over-fitting
- Try ``lambda_l1``, ``lambda_l2`` and ``min_gain_to_split`` for regularization - Try ``lambda_l1``, ``lambda_l2`` and ``min_gain_to_split`` for regularization
- Try ``max_depth`` to avoid growing deep tree - Try ``max_depth`` to avoid growing deep tree
- Try ``extra_trees``
...@@ -308,6 +308,18 @@ Learning Control Parameters ...@@ -308,6 +308,18 @@ Learning Control Parameters
- random seed for ``feature_fraction`` - random seed for ``feature_fraction``
- ``extra_trees`` :raw-html:`<a id="extra_trees" title="Permalink to this parameter" href="#extra_trees">&#x1F517;&#xFE0E;</a>`, default = ``false``, type = bool
- use extremely randomized trees
- if set to ``true``, when evaluating node splits LightGBM will check only one randomly-chosen threshold for each feature
- can be used to deal with over-fitting
- ``extra_seed`` :raw-html:`<a id="extra_seed" title="Permalink to this parameter" href="#extra_seed">&#x1F517;&#xFE0E;</a>`, default = ``6``, type = int
- random seed for selecting thresholds when ``extra_trees`` is true
- ``early_stopping_round`` :raw-html:`<a id="early_stopping_round" title="Permalink to this parameter" href="#early_stopping_round">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int, aliases: ``early_stopping_rounds``, ``early_stopping``, ``n_iter_no_change`` - ``early_stopping_round`` :raw-html:`<a id="early_stopping_round" title="Permalink to this parameter" href="#early_stopping_round">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int, aliases: ``early_stopping_rounds``, ``early_stopping``, ``n_iter_no_change``
- will stop training if one metric of one validation data doesn't improve in last ``early_stopping_round`` rounds - will stop training if one metric of one validation data doesn't improve in last ``early_stopping_round`` rounds
......
...@@ -307,6 +307,14 @@ struct Config { ...@@ -307,6 +307,14 @@ struct Config {
// desc = random seed for ``feature_fraction`` // desc = random seed for ``feature_fraction``
int feature_fraction_seed = 2; int feature_fraction_seed = 2;
// desc = use extremely randomized trees
// desc = if set to ``true``, when evaluating node splits LightGBM will check only one randomly-chosen threshold for each feature
// desc = can be used to deal with over-fitting
bool extra_trees = false;
// desc = random seed for selecting thresholds when ``extra_trees`` is true
int extra_seed = 6;
// alias = early_stopping_rounds, early_stopping, n_iter_no_change // alias = early_stopping_rounds, early_stopping, n_iter_no_change
// desc = will stop training if one metric of one validation data doesn't improve in last ``early_stopping_round`` rounds // desc = will stop training if one metric of one validation data doesn't improve in last ``early_stopping_round`` rounds
// desc = ``<= 0`` means disable // desc = ``<= 0`` means disable
......
...@@ -193,6 +193,7 @@ void Config::Set(const std::unordered_map<std::string, std::string>& params) { ...@@ -193,6 +193,7 @@ void Config::Set(const std::unordered_map<std::string, std::string>& params) {
drop_seed = static_cast<int>(rand.NextShort(0, int_max)); drop_seed = static_cast<int>(rand.NextShort(0, int_max));
feature_fraction_seed = static_cast<int>(rand.NextShort(0, int_max)); feature_fraction_seed = static_cast<int>(rand.NextShort(0, int_max));
objective_seed = static_cast<int>(rand.NextShort(0, int_max)); objective_seed = static_cast<int>(rand.NextShort(0, int_max));
extra_seed = static_cast<int>(rand.NextShort(0, int_max));
} }
GetTaskType(params, &task); GetTaskType(params, &task);
......
...@@ -198,6 +198,8 @@ const std::unordered_set<std::string>& Config::parameter_set() { ...@@ -198,6 +198,8 @@ const std::unordered_set<std::string>& Config::parameter_set() {
"feature_fraction", "feature_fraction",
"feature_fraction_bynode", "feature_fraction_bynode",
"feature_fraction_seed", "feature_fraction_seed",
"extra_trees",
"extra_seed",
"early_stopping_round", "early_stopping_round",
"first_metric_only", "first_metric_only",
"max_delta_step", "max_delta_step",
...@@ -353,6 +355,10 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str ...@@ -353,6 +355,10 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetInt(params, "feature_fraction_seed", &feature_fraction_seed); GetInt(params, "feature_fraction_seed", &feature_fraction_seed);
GetBool(params, "extra_trees", &extra_trees);
GetInt(params, "extra_seed", &extra_seed);
GetInt(params, "early_stopping_round", &early_stopping_round); GetInt(params, "early_stopping_round", &early_stopping_round);
GetBool(params, "first_metric_only", &first_metric_only); GetBool(params, "first_metric_only", &first_metric_only);
...@@ -615,6 +621,8 @@ std::string Config::SaveMembersToString() const { ...@@ -615,6 +621,8 @@ std::string Config::SaveMembersToString() const {
str_buf << "[feature_fraction: " << feature_fraction << "]\n"; str_buf << "[feature_fraction: " << feature_fraction << "]\n";
str_buf << "[feature_fraction_bynode: " << feature_fraction_bynode << "]\n"; str_buf << "[feature_fraction_bynode: " << feature_fraction_bynode << "]\n";
str_buf << "[feature_fraction_seed: " << feature_fraction_seed << "]\n"; str_buf << "[feature_fraction_seed: " << feature_fraction_seed << "]\n";
str_buf << "[extra_trees: " << extra_trees << "]\n";
str_buf << "[extra_seed: " << extra_seed << "]\n";
str_buf << "[early_stopping_round: " << early_stopping_round << "]\n"; str_buf << "[early_stopping_round: " << early_stopping_round << "]\n";
str_buf << "[first_metric_only: " << first_metric_only << "]\n"; str_buf << "[first_metric_only: " << first_metric_only << "]\n";
str_buf << "[max_delta_step: " << max_delta_step << "]\n"; str_buf << "[max_delta_step: " << max_delta_step << "]\n";
......
...@@ -64,6 +64,7 @@ public: ...@@ -64,6 +64,7 @@ public:
find_best_threshold_fun_ = std::bind(&FeatureHistogram::FindBestThresholdCategorical, this, std::placeholders::_1 find_best_threshold_fun_ = std::bind(&FeatureHistogram::FindBestThresholdCategorical, this, std::placeholders::_1
, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6); , std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6);
} }
rand_ = Random(meta_->config->extra_seed);
} }
hist_t* RawData() { hist_t* RawData() {
...@@ -93,16 +94,36 @@ public: ...@@ -93,16 +94,36 @@ public:
double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian, double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian,
meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step); meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step);
double min_gain_shift = gain_shift + meta_->config->min_gain_to_split; double min_gain_shift = gain_shift + meta_->config->min_gain_to_split;
int rand_threshold = 0;
if (meta_->num_bin - 2 > 0){
rand_threshold = rand_.NextInt(0, meta_->num_bin - 2);
}
bool is_rand = meta_->config->extra_trees;
if (meta_->num_bin > 2 && meta_->missing_type != MissingType::None) { if (meta_->num_bin > 2 && meta_->missing_type != MissingType::None) {
if (meta_->missing_type == MissingType::Zero) { if (meta_->missing_type == MissingType::Zero) {
FindBestThresholdSequence(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, true, false); if (is_rand) {
FindBestThresholdSequence(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, 1, true, false); FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, true, false, rand_threshold);
FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, 1, true, false, rand_threshold);
}
else {
FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, true, false, rand_threshold);
FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, 1, true, false, rand_threshold);
}
} else { } else {
FindBestThresholdSequence(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, false, true); if (is_rand) {
FindBestThresholdSequence(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, 1, false, true); FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, false, true, rand_threshold);
FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, 1, false, true, rand_threshold);
} else {
FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, false, true, rand_threshold);
FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, 1, false, true, rand_threshold);
}
} }
} else { } else {
FindBestThresholdSequence(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, false, false); if (is_rand) {
FindBestThresholdSequence<true>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, false, false, rand_threshold);
} else {
FindBestThresholdSequence<false>(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, false, false, rand_threshold);
}
// fix the direction error when only have 2 bins // fix the direction error when only have 2 bins
if (meta_->missing_type == MissingType::NaN) { if (meta_->missing_type == MissingType::NaN) {
output->default_left = false; output->default_left = false;
...@@ -192,6 +213,11 @@ public: ...@@ -192,6 +213,11 @@ public:
find_direction.push_back(-1); find_direction.push_back(-1);
start_position.push_back(used_bin - 1); start_position.push_back(used_bin - 1);
const int max_num_cat = std::min(meta_->config->max_cat_threshold, (used_bin + 1) / 2); const int max_num_cat = std::min(meta_->config->max_cat_threshold, (used_bin + 1) / 2);
int max_threshold = std::max(std::min(max_num_cat, used_bin) - 1, 0);
int rand_threshold = 0;
if (max_threshold > 0) {
rand_threshold = rand_.NextInt(0, max_threshold);
}
is_splittable_ = false; is_splittable_ = false;
for (size_t out_i = 0; out_i < find_direction.size(); ++out_i) { for (size_t out_i = 0; out_i < find_direction.size(); ++out_i) {
...@@ -227,6 +253,7 @@ public: ...@@ -227,6 +253,7 @@ public:
cnt_cur_group = 0; cnt_cur_group = 0;
double sum_right_gradient = sum_gradient - sum_left_gradient; double sum_right_gradient = sum_gradient - sum_left_gradient;
if (!meta_->config->extra_trees || i == rand_threshold) {
double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian, double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian,
meta_->config->lambda_l1, l2, meta_->config->max_delta_step, meta_->config->lambda_l1, l2, meta_->config->max_delta_step,
min_constraint, max_constraint, 0); min_constraint, max_constraint, 0);
...@@ -243,6 +270,7 @@ public: ...@@ -243,6 +270,7 @@ public:
} }
} }
} }
}
if (is_splittable_) { if (is_splittable_) {
output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian, output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian,
...@@ -516,8 +544,9 @@ private: ...@@ -516,8 +544,9 @@ private:
return -(2.0 * sg_l1 * output + (sum_hessians + l2) * output * output); return -(2.0 * sg_l1 * output + (sum_hessians + l2) * output * output);
} }
template<bool is_rand>
void FindBestThresholdSequence(double sum_gradient, double sum_hessian, data_size_t num_data, double min_constraint, double max_constraint, void FindBestThresholdSequence(double sum_gradient, double sum_hessian, data_size_t num_data, double min_constraint, double max_constraint,
double min_gain_shift, SplitInfo* output, int dir, bool skip_default_bin, bool use_na_as_missing) { double min_gain_shift, SplitInfo* output, int dir, bool skip_default_bin, bool use_na_as_missing, int rand_threshold) {
const int8_t offset = meta_->offset; const int8_t offset = meta_->offset;
double best_sum_left_gradient = NAN; double best_sum_left_gradient = NAN;
...@@ -557,6 +586,7 @@ private: ...@@ -557,6 +586,7 @@ private:
if (sum_left_hessian < meta_->config->min_sum_hessian_in_leaf) break; if (sum_left_hessian < meta_->config->min_sum_hessian_in_leaf) break;
double sum_left_gradient = sum_gradient - sum_right_gradient; double sum_left_gradient = sum_gradient - sum_right_gradient;
if (!is_rand || t - 1 + offset == rand_threshold) {
// current split gain // current split gain
double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian, double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian,
meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step, meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step,
...@@ -576,6 +606,7 @@ private: ...@@ -576,6 +606,7 @@ private:
best_gain = current_gain; best_gain = current_gain;
} }
} }
}
} else { } else {
double sum_left_gradient = 0.0f; double sum_left_gradient = 0.0f;
double sum_left_hessian = kEpsilon; double sum_left_hessian = kEpsilon;
...@@ -619,6 +650,7 @@ private: ...@@ -619,6 +650,7 @@ private:
if (sum_right_hessian < meta_->config->min_sum_hessian_in_leaf) break; if (sum_right_hessian < meta_->config->min_sum_hessian_in_leaf) break;
double sum_right_gradient = sum_gradient - sum_left_gradient; double sum_right_gradient = sum_gradient - sum_left_gradient;
if (!is_rand || t + offset == rand_threshold) {
// current split gain // current split gain
double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian, double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian,
meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step, meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step,
...@@ -638,6 +670,7 @@ private: ...@@ -638,6 +670,7 @@ private:
} }
} }
} }
}
if (is_splittable_ && best_gain > output->gain) { if (is_splittable_ && best_gain > output->gain) {
// update split information // update split information
...@@ -664,6 +697,8 @@ private: ...@@ -664,6 +697,8 @@ private:
/*! \brief sum of gradient of each bin */ /*! \brief sum of gradient of each bin */
hist_t* data_; hist_t* data_;
bool is_splittable_ = true; bool is_splittable_ = true;
/*! \brief random number generator for extremely randomized trees */
Random rand_;
std::function<void(double, double, data_size_t, double, double, SplitInfo*)> find_best_threshold_fun_; std::function<void(double, double, data_size_t, double, double, SplitInfo*)> find_best_threshold_fun_;
}; };
......
...@@ -1811,6 +1811,24 @@ class TestEngine(unittest.TestCase): ...@@ -1811,6 +1811,24 @@ class TestEngine(unittest.TestCase):
self.assertNotAlmostEqual(predicted[0], predicted[1]) self.assertNotAlmostEqual(predicted[0], predicted[1])
self.assertAlmostEqual(predicted[1], predicted[2]) self.assertAlmostEqual(predicted[1], predicted[2])
def test_extra_trees(self):
    # Extra-trees mode evaluates only one randomly chosen threshold per
    # feature at each split, which acts as extra regularization: the model
    # should fit the training data less closely than a standard GBDT, so
    # its training MSE must be strictly larger.
    X, y = load_boston(True)
    train_set = lgb.Dataset(X, label=y)
    base_params = {'objective': 'regression',
                   'num_leaves': 32,
                   'verbose': -1,
                   'extra_trees': False,
                   'seed': 0}
    # Baseline: ordinary boosting with extra_trees disabled.
    booster = lgb.train(base_params, train_set, num_boost_round=10)
    baseline_error = mean_squared_error(y, booster.predict(X))
    # Same configuration with extra_trees enabled.
    base_params['extra_trees'] = True
    booster = lgb.train(base_params, train_set, num_boost_round=10)
    extra_error = mean_squared_error(y, booster.predict(X))
    self.assertLess(baseline_error, extra_error)
@unittest.skipIf(not lgb.compat.PANDAS_INSTALLED, 'pandas is not installed') @unittest.skipIf(not lgb.compat.PANDAS_INSTALLED, 'pandas is not installed')
def test_trees_to_dataframe(self): def test_trees_to_dataframe(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment