Unverified Commit f01b2aca authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

first metric only in earlystopping for cli (#2172)

* first metric only in earlystopping for cli

* code clean

* added note about CLI only usage

* removed note about CLI only usage
parent c9d681ac
......@@ -240,6 +240,10 @@ Learning Control Parameters
- ``<= 0`` means disable
- ``first_metric_only`` :raw-html:`<a id="first_metric_only" title="Permalink to this parameter" href="#first_metric_only">&#x1F517;&#xFE0E;</a>`, default = ``false``, type = bool
- set this to ``true``, if you want to use only the first metric for early stopping
- ``max_delta_step`` :raw-html:`<a id="max_delta_step" title="Permalink to this parameter" href="#max_delta_step">&#x1F517;&#xFE0E;</a>`, default = ``0.0``, type = double, aliases: ``max_tree_output``, ``max_leaf_output``
- used to limit the max output of tree leaves
......
......@@ -260,6 +260,9 @@ struct Config {
// desc = ``<= 0`` means disable
int early_stopping_round = 0;
// desc = set this to ``true``, if you want to use only the first metric for early stopping
bool first_metric_only = false;
// alias = max_tree_output, max_leaf_output
// desc = used to limit the max output of tree leaves
// desc = ``<= 0`` means no constraint
......
......@@ -22,6 +22,7 @@ GBDT::GBDT() : iter_(0),
train_data_(nullptr),
objective_function_(nullptr),
early_stopping_round_(0),
es_first_metric_only_(false),
max_feature_idx_(0),
num_tree_per_iteration_(1),
num_class_(1),
......@@ -51,6 +52,7 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
num_class_ = config->num_class;
config_ = std::unique_ptr<Config>(new Config(*config));
early_stopping_round_ = config_->early_stopping_round;
es_first_metric_only_ = config_->first_metric_only;
shrinkage_rate_ = config_->learning_rate;
std::string forced_splits_path = config->forcedsplits_filename;
......@@ -129,20 +131,18 @@ void GBDT::AddValidDataset(const Dataset* valid_data,
}
valid_score_updater_.push_back(std::move(new_score_updater));
valid_metrics_.emplace_back();
if (early_stopping_round_ > 0) {
best_iter_.emplace_back();
best_score_.emplace_back();
best_msg_.emplace_back();
}
for (const auto& metric : valid_metrics) {
valid_metrics_.back().push_back(metric);
if (early_stopping_round_ > 0) {
best_iter_.back().push_back(0);
best_score_.back().push_back(kMinScore);
best_msg_.back().emplace_back();
}
}
valid_metrics_.back().shrink_to_fit();
if (early_stopping_round_ > 0) {
auto num_metrics = valid_metrics.size();
if (es_first_metric_only_) { num_metrics = 1; }
best_iter_.emplace_back(num_metrics, 0);
best_score_.emplace_back(num_metrics, kMinScore);
best_msg_.emplace_back(num_metrics);
}
}
void GBDT::Boosting() {
......@@ -514,6 +514,7 @@ std::string GBDT::OutputMetric(int iter) {
msg_buf << tmp_buf.str() << '\n';
}
}
if (es_first_metric_only_ && j > 0) { continue; }
if (ret.empty() && early_stopping_round_ > 0) {
auto cur_score = valid_metrics_[i][j]->factor_to_bigger_better() * test_scores.back();
if (cur_score > best_score_[i][j]) {
......
......@@ -434,6 +434,8 @@ class GBDT : public GBDTBase {
std::vector<std::vector<const Metric*>> valid_metrics_;
/*! \brief Number of rounds for early stopping */
int early_stopping_round_;
/*! \brief Only use first metric for early stopping */
bool es_first_metric_only_;
/*! \brief Best iteration(s) for early stopping */
std::vector<std::vector<int>> best_iter_;
/*! \brief Best score(s) for early stopping */
......
......@@ -181,6 +181,7 @@ std::unordered_set<std::string> Config::parameter_set({
"feature_fraction",
"feature_fraction_seed",
"early_stopping_round",
"first_metric_only",
"max_delta_step",
"lambda_l1",
"lambda_l2",
......@@ -312,6 +313,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetInt(params, "early_stopping_round", &early_stopping_round);
GetBool(params, "first_metric_only", &first_metric_only);
GetDouble(params, "max_delta_step", &max_delta_step);
GetDouble(params, "lambda_l1", &lambda_l1);
......@@ -556,6 +559,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[feature_fraction: " << feature_fraction << "]\n";
str_buf << "[feature_fraction_seed: " << feature_fraction_seed << "]\n";
str_buf << "[early_stopping_round: " << early_stopping_round << "]\n";
str_buf << "[first_metric_only: " << first_metric_only << "]\n";
str_buf << "[max_delta_step: " << max_delta_step << "]\n";
str_buf << "[lambda_l1: " << lambda_l1 << "]\n";
str_buf << "[lambda_l2: " << lambda_l2 << "]\n";
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment