Commit 25200c3a authored by Guolin Ke's avatar Guolin Ke
Browse files

fix #132

parent fca93b78
...@@ -146,12 +146,14 @@ void GBDT::AddValidDataset(const Dataset* valid_data, ...@@ -146,12 +146,14 @@ void GBDT::AddValidDataset(const Dataset* valid_data,
if (early_stopping_round_ > 0) { if (early_stopping_round_ > 0) {
best_iter_.emplace_back(); best_iter_.emplace_back();
best_score_.emplace_back(); best_score_.emplace_back();
best_msg_.emplace_back();
} }
for (const auto& metric : valid_metrics) { for (const auto& metric : valid_metrics) {
valid_metrics_.back().push_back(metric); valid_metrics_.back().push_back(metric);
if (early_stopping_round_ > 0) { if (early_stopping_round_ > 0) {
best_iter_.back().push_back(0); best_iter_.back().push_back(0);
best_score_.back().push_back(kMinScore); best_score_.back().push_back(kMinScore);
best_msg_.back().emplace_back();
} }
} }
valid_metrics_.back().shrink_to_fit(); valid_metrics_.back().shrink_to_fit();
...@@ -278,10 +280,12 @@ void GBDT::RollbackOneIter() { ...@@ -278,10 +280,12 @@ void GBDT::RollbackOneIter() {
bool GBDT::EvalAndCheckEarlyStopping() { bool GBDT::EvalAndCheckEarlyStopping() {
bool is_met_early_stopping = false; bool is_met_early_stopping = false;
// print message for metric // print message for metric
is_met_early_stopping = OutputMetric(iter_); auto best_msg = OutputMetric(iter_);
is_met_early_stopping = !best_msg.empty();
if (is_met_early_stopping) { if (is_met_early_stopping) {
Log::Info("Early stopping at iteration %d, the best iteration round is %d", Log::Info("Early stopping at iteration %d, the best iteration round is %d",
iter_, iter_ - early_stopping_round_); iter_, iter_ - early_stopping_round_);
Log::Info("Output of best iteration round:\n%s", best_msg.c_str());
// pop last early_stopping_round_ models // pop last early_stopping_round_ models
for (int i = 0; i < early_stopping_round_ * num_class_; ++i) { for (int i = 0; i < early_stopping_round_ * num_class_; ++i) {
models_.pop_back(); models_.pop_back();
...@@ -299,41 +303,62 @@ void GBDT::UpdateScore(const Tree* tree, const int curr_class) { ...@@ -299,41 +303,62 @@ void GBDT::UpdateScore(const Tree* tree, const int curr_class) {
} }
} }
bool GBDT::OutputMetric(int iter) { std::string GBDT::OutputMetric(int iter) {
bool ret = false; bool need_output = (iter % gbdt_config_->output_freq) == 0;
std::string ret = "";
std::stringstream msg_buf;
std::vector<std::pair<int, int>> meet_early_stopping_pairs;
// print training metric // print training metric
if ((iter % gbdt_config_->output_freq) == 0) { if (need_output) {
for (auto& sub_metric : training_metrics_) { for (auto& sub_metric : training_metrics_) {
auto name = sub_metric->GetName(); auto name = sub_metric->GetName();
auto scores = sub_metric->Eval(train_score_updater_->score()); auto scores = sub_metric->Eval(train_score_updater_->score());
for (size_t k = 0; k < name.size(); ++k) { for (size_t k = 0; k < name.size(); ++k) {
Log::Info("Iteration:%d, training %s : %f", iter, name[k].c_str(), scores[k]); std::stringstream tmp_buf;
tmp_buf << "Iteration:" << iter
<< ", training " << name[k]
<< " : " << scores[k];
Log::Info(tmp_buf.str().c_str());
if (early_stopping_round_ > 0) {
msg_buf << tmp_buf.str() << std::endl;
}
} }
} }
} }
// print validation metric // print validation metric
if ((iter % gbdt_config_->output_freq) == 0 || early_stopping_round_ > 0) { if (need_output || early_stopping_round_ > 0) {
for (size_t i = 0; i < valid_metrics_.size(); ++i) { for (size_t i = 0; i < valid_metrics_.size(); ++i) {
for (size_t j = 0; j < valid_metrics_[i].size(); ++j) { for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score()); auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score());
if ((iter % gbdt_config_->output_freq) == 0) { auto name = valid_metrics_[i][j]->GetName();
auto name = valid_metrics_[i][j]->GetName(); for (size_t k = 0; k < name.size(); ++k) {
for (size_t k = 0; k < name.size(); ++k) { std::stringstream tmp_buf;
Log::Info("Iteration:%d, valid_%d %s : %f", iter, i + 1, name[k].c_str(), test_scores[k]); tmp_buf << "Iteration:" << iter
<< ", valid_" << i + 1 << " " << name[k]
<< " : " << test_scores[k];
if (need_output) {
Log::Info(tmp_buf.str().c_str());
}
if (early_stopping_round_ > 0) {
msg_buf << tmp_buf.str() << std::endl;
} }
} }
if (!ret && early_stopping_round_ > 0) { if (ret.empty() && early_stopping_round_ > 0) {
auto cur_score = valid_metrics_[i][j]->factor_to_bigger_better() * test_scores.back(); auto cur_score = valid_metrics_[i][j]->factor_to_bigger_better() * test_scores.back();
if (cur_score > best_score_[i][j]) { if (cur_score > best_score_[i][j]) {
best_score_[i][j] = cur_score; best_score_[i][j] = cur_score;
best_iter_[i][j] = iter; best_iter_[i][j] = iter;
meet_early_stopping_pairs.emplace_back(i, j);
} else { } else {
if (iter - best_iter_[i][j] >= early_stopping_round_) { ret = true; } if (iter - best_iter_[i][j] >= early_stopping_round_) { ret = best_msg_[i][j]; }
} }
} }
} }
} }
} }
for (auto& pair : meet_early_stopping_pairs) {
best_msg_[pair.first][pair.second] = msg_buf.str();
}
return ret; return ret;
} }
...@@ -419,14 +444,14 @@ void GBDT::Boosting() { ...@@ -419,14 +444,14 @@ void GBDT::Boosting() {
} }
std::string GBDT::DumpModel() const { std::string GBDT::DumpModel() const {
std::stringstream ss; std::stringstream str_buf;
ss << "{"; str_buf << "{";
ss << "\"name\":\"" << Name() << "\"," << std::endl; str_buf << "\"name\":\"" << Name() << "\"," << std::endl;
ss << "\"num_class\":" << num_class_ << "," << std::endl; str_buf << "\"num_class\":" << num_class_ << "," << std::endl;
ss << "\"label_index\":" << label_idx_ << "," << std::endl; str_buf << "\"label_index\":" << label_idx_ << "," << std::endl;
ss << "\"max_feature_idx\":" << max_feature_idx_ << "," << std::endl; str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << std::endl;
ss << "\"sigmoid\":" << sigmoid_ << "," << std::endl; str_buf << "\"sigmoid\":" << sigmoid_ << "," << std::endl;
// output feature names // output feature names
auto feature_names = std::ref(feature_names_); auto feature_names = std::ref(feature_names_);
...@@ -434,25 +459,25 @@ std::string GBDT::DumpModel() const { ...@@ -434,25 +459,25 @@ std::string GBDT::DumpModel() const {
feature_names = std::ref(train_data_->feature_names()); feature_names = std::ref(train_data_->feature_names());
} }
ss << "\"feature_names\":[\"" str_buf << "\"feature_names\":[\""
<< Common::Join(feature_names.get(), "\",\"") << "\"]," << Common::Join(feature_names.get(), "\",\"") << "\"],"
<< std::endl; << std::endl;
ss << "\"tree_info\":["; str_buf << "\"tree_info\":[";
for (int i = 0; i < static_cast<int>(models_.size()); ++i) { for (int i = 0; i < static_cast<int>(models_.size()); ++i) {
if (i > 0) { if (i > 0) {
ss << ","; str_buf << ",";
} }
ss << "{"; str_buf << "{";
ss << "\"tree_index\":" << i << ","; str_buf << "\"tree_index\":" << i << ",";
ss << models_[i]->ToJSON(); str_buf << models_[i]->ToJSON();
ss << "}"; str_buf << "}";
} }
ss << "]" << std::endl; str_buf << "]" << std::endl;
ss << "}" << std::endl; str_buf << "}" << std::endl;
return ss.str(); return str_buf.str();
} }
void GBDT::SaveModelToFile(int num_iteration, const char* filename) const { void GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
......
...@@ -225,8 +225,9 @@ protected: ...@@ -225,8 +225,9 @@ protected:
/*! /*!
* \brief Print metric result of current iteration * \brief Print metric result of current iteration
* \param iter Current interation * \param iter Current interation
* \return best_msg if met early_stopping
*/ */
bool OutputMetric(int iter); std::string OutputMetric(int iter);
/*! /*!
* \brief Calculate feature importances * \brief Calculate feature importances
* \param last_iter Last tree use to calculate * \param last_iter Last tree use to calculate
...@@ -252,9 +253,12 @@ protected: ...@@ -252,9 +253,12 @@ protected:
std::vector<std::vector<const Metric*>> valid_metrics_; std::vector<std::vector<const Metric*>> valid_metrics_;
/*! \brief Number of rounds for early stopping */ /*! \brief Number of rounds for early stopping */
int early_stopping_round_; int early_stopping_round_;
/*! \brief Best score(s) for early stopping */ /*! \brief Best iteration(s) for early stopping */
std::vector<std::vector<int>> best_iter_; std::vector<std::vector<int>> best_iter_;
/*! \brief Best score(s) for early stopping */
std::vector<std::vector<double>> best_score_; std::vector<std::vector<double>> best_score_;
/*! \brief output message of best iteration */
std::vector<std::vector<std::string>> best_msg_;
/*! \brief Trained models(trees) */ /*! \brief Trained models(trees) */
std::vector<std::unique_ptr<Tree>> models_; std::vector<std::unique_ptr<Tree>> models_;
/*! \brief Max feature index of training data*/ /*! \brief Max feature index of training data*/
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment