Unverified commit f01b2aca, authored by Guolin Ke and committed by GitHub
Browse files

first metric only in earlystopping for cli (#2172)

* first metric only in earlystopping for cli

* code clean

* added note about CLI only usage

* removed note about CLI only usage
parent c9d681ac
...@@ -240,6 +240,10 @@ Learning Control Parameters ...@@ -240,6 +240,10 @@ Learning Control Parameters
- ``<= 0`` means disable - ``<= 0`` means disable
- ``first_metric_only`` :raw-html:`<a id="first_metric_only" title="Permalink to this parameter" href="#first_metric_only">&#x1F517;&#xFE0E;</a>`, default = ``false``, type = bool
- set this to ``true``, if you want to use only the first metric for early stopping
- ``max_delta_step`` :raw-html:`<a id="max_delta_step" title="Permalink to this parameter" href="#max_delta_step">&#x1F517;&#xFE0E;</a>`, default = ``0.0``, type = double, aliases: ``max_tree_output``, ``max_leaf_output`` - ``max_delta_step`` :raw-html:`<a id="max_delta_step" title="Permalink to this parameter" href="#max_delta_step">&#x1F517;&#xFE0E;</a>`, default = ``0.0``, type = double, aliases: ``max_tree_output``, ``max_leaf_output``
- used to limit the max output of tree leaves - used to limit the max output of tree leaves
......
...@@ -260,6 +260,9 @@ struct Config { ...@@ -260,6 +260,9 @@ struct Config {
// desc = ``<= 0`` means disable // desc = ``<= 0`` means disable
int early_stopping_round = 0; int early_stopping_round = 0;
// desc = set this to ``true``, if you want to use only the first metric for early stopping
bool first_metric_only = false;
// alias = max_tree_output, max_leaf_output // alias = max_tree_output, max_leaf_output
// desc = used to limit the max output of tree leaves // desc = used to limit the max output of tree leaves
// desc = ``<= 0`` means no constraint // desc = ``<= 0`` means no constraint
......
...@@ -22,6 +22,7 @@ GBDT::GBDT() : iter_(0), ...@@ -22,6 +22,7 @@ GBDT::GBDT() : iter_(0),
train_data_(nullptr), train_data_(nullptr),
objective_function_(nullptr), objective_function_(nullptr),
early_stopping_round_(0), early_stopping_round_(0),
es_first_metric_only_(false),
max_feature_idx_(0), max_feature_idx_(0),
num_tree_per_iteration_(1), num_tree_per_iteration_(1),
num_class_(1), num_class_(1),
...@@ -51,6 +52,7 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective ...@@ -51,6 +52,7 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
num_class_ = config->num_class; num_class_ = config->num_class;
config_ = std::unique_ptr<Config>(new Config(*config)); config_ = std::unique_ptr<Config>(new Config(*config));
early_stopping_round_ = config_->early_stopping_round; early_stopping_round_ = config_->early_stopping_round;
es_first_metric_only_ = config_->first_metric_only;
shrinkage_rate_ = config_->learning_rate; shrinkage_rate_ = config_->learning_rate;
std::string forced_splits_path = config->forcedsplits_filename; std::string forced_splits_path = config->forcedsplits_filename;
...@@ -129,20 +131,18 @@ void GBDT::AddValidDataset(const Dataset* valid_data, ...@@ -129,20 +131,18 @@ void GBDT::AddValidDataset(const Dataset* valid_data,
} }
valid_score_updater_.push_back(std::move(new_score_updater)); valid_score_updater_.push_back(std::move(new_score_updater));
valid_metrics_.emplace_back(); valid_metrics_.emplace_back();
if (early_stopping_round_ > 0) {
best_iter_.emplace_back();
best_score_.emplace_back();
best_msg_.emplace_back();
}
for (const auto& metric : valid_metrics) { for (const auto& metric : valid_metrics) {
valid_metrics_.back().push_back(metric); valid_metrics_.back().push_back(metric);
if (early_stopping_round_ > 0) {
best_iter_.back().push_back(0);
best_score_.back().push_back(kMinScore);
best_msg_.back().emplace_back();
}
} }
valid_metrics_.back().shrink_to_fit(); valid_metrics_.back().shrink_to_fit();
if (early_stopping_round_ > 0) {
auto num_metrics = valid_metrics.size();
if (es_first_metric_only_) { num_metrics = 1; }
best_iter_.emplace_back(num_metrics, 0);
best_score_.emplace_back(num_metrics, kMinScore);
best_msg_.emplace_back(num_metrics);
}
} }
void GBDT::Boosting() { void GBDT::Boosting() {
...@@ -514,6 +514,7 @@ std::string GBDT::OutputMetric(int iter) { ...@@ -514,6 +514,7 @@ std::string GBDT::OutputMetric(int iter) {
msg_buf << tmp_buf.str() << '\n'; msg_buf << tmp_buf.str() << '\n';
} }
} }
if (es_first_metric_only_ && j > 0) { continue; }
if (ret.empty() && early_stopping_round_ > 0) { if (ret.empty() && early_stopping_round_ > 0) {
auto cur_score = valid_metrics_[i][j]->factor_to_bigger_better() * test_scores.back(); auto cur_score = valid_metrics_[i][j]->factor_to_bigger_better() * test_scores.back();
if (cur_score > best_score_[i][j]) { if (cur_score > best_score_[i][j]) {
......
...@@ -434,6 +434,8 @@ class GBDT : public GBDTBase { ...@@ -434,6 +434,8 @@ class GBDT : public GBDTBase {
std::vector<std::vector<const Metric*>> valid_metrics_; std::vector<std::vector<const Metric*>> valid_metrics_;
/*! \brief Number of rounds for early stopping */ /*! \brief Number of rounds for early stopping */
int early_stopping_round_; int early_stopping_round_;
/*! \brief Only use first metric for early stopping */
bool es_first_metric_only_;
/*! \brief Best iteration(s) for early stopping */ /*! \brief Best iteration(s) for early stopping */
std::vector<std::vector<int>> best_iter_; std::vector<std::vector<int>> best_iter_;
/*! \brief Best score(s) for early stopping */ /*! \brief Best score(s) for early stopping */
......
...@@ -181,6 +181,7 @@ std::unordered_set<std::string> Config::parameter_set({ ...@@ -181,6 +181,7 @@ std::unordered_set<std::string> Config::parameter_set({
"feature_fraction", "feature_fraction",
"feature_fraction_seed", "feature_fraction_seed",
"early_stopping_round", "early_stopping_round",
"first_metric_only",
"max_delta_step", "max_delta_step",
"lambda_l1", "lambda_l1",
"lambda_l2", "lambda_l2",
...@@ -312,6 +313,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str ...@@ -312,6 +313,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetInt(params, "early_stopping_round", &early_stopping_round); GetInt(params, "early_stopping_round", &early_stopping_round);
GetBool(params, "first_metric_only", &first_metric_only);
GetDouble(params, "max_delta_step", &max_delta_step); GetDouble(params, "max_delta_step", &max_delta_step);
GetDouble(params, "lambda_l1", &lambda_l1); GetDouble(params, "lambda_l1", &lambda_l1);
...@@ -556,6 +559,7 @@ std::string Config::SaveMembersToString() const { ...@@ -556,6 +559,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[feature_fraction: " << feature_fraction << "]\n"; str_buf << "[feature_fraction: " << feature_fraction << "]\n";
str_buf << "[feature_fraction_seed: " << feature_fraction_seed << "]\n"; str_buf << "[feature_fraction_seed: " << feature_fraction_seed << "]\n";
str_buf << "[early_stopping_round: " << early_stopping_round << "]\n"; str_buf << "[early_stopping_round: " << early_stopping_round << "]\n";
str_buf << "[first_metric_only: " << first_metric_only << "]\n";
str_buf << "[max_delta_step: " << max_delta_step << "]\n"; str_buf << "[max_delta_step: " << max_delta_step << "]\n";
str_buf << "[lambda_l1: " << lambda_l1 << "]\n"; str_buf << "[lambda_l1: " << lambda_l1 << "]\n";
str_buf << "[lambda_l2: " << lambda_l2 << "]\n"; str_buf << "[lambda_l2: " << lambda_l2 << "]\n";
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment