Commit 25200c3a authored by Guolin Ke
Browse files

fix #132

parent fca93b78
......@@ -146,12 +146,14 @@ void GBDT::AddValidDataset(const Dataset* valid_data,
if (early_stopping_round_ > 0) {
best_iter_.emplace_back();
best_score_.emplace_back();
best_msg_.emplace_back();
}
for (const auto& metric : valid_metrics) {
valid_metrics_.back().push_back(metric);
if (early_stopping_round_ > 0) {
best_iter_.back().push_back(0);
best_score_.back().push_back(kMinScore);
best_msg_.back().emplace_back();
}
}
valid_metrics_.back().shrink_to_fit();
......@@ -278,10 +280,12 @@ void GBDT::RollbackOneIter() {
bool GBDT::EvalAndCheckEarlyStopping() {
bool is_met_early_stopping = false;
// print message for metric
is_met_early_stopping = OutputMetric(iter_);
auto best_msg = OutputMetric(iter_);
is_met_early_stopping = !best_msg.empty();
if (is_met_early_stopping) {
Log::Info("Early stopping at iteration %d, the best iteration round is %d",
iter_, iter_ - early_stopping_round_);
Log::Info("Output of best iteration round:\n%s", best_msg.c_str());
// pop last early_stopping_round_ models
for (int i = 0; i < early_stopping_round_ * num_class_; ++i) {
models_.pop_back();
......@@ -299,41 +303,62 @@ void GBDT::UpdateScore(const Tree* tree, const int curr_class) {
}
}
bool GBDT::OutputMetric(int iter) {
bool ret = false;
std::string GBDT::OutputMetric(int iter) {
bool need_output = (iter % gbdt_config_->output_freq) == 0;
std::string ret = "";
std::stringstream msg_buf;
std::vector<std::pair<int, int>> meet_early_stopping_pairs;
// print training metric
if ((iter % gbdt_config_->output_freq) == 0) {
if (need_output) {
for (auto& sub_metric : training_metrics_) {
auto name = sub_metric->GetName();
auto scores = sub_metric->Eval(train_score_updater_->score());
for (size_t k = 0; k < name.size(); ++k) {
Log::Info("Iteration:%d, training %s : %f", iter, name[k].c_str(), scores[k]);
std::stringstream tmp_buf;
tmp_buf << "Iteration:" << iter
<< ", training " << name[k]
<< " : " << scores[k];
Log::Info(tmp_buf.str().c_str());
if (early_stopping_round_ > 0) {
msg_buf << tmp_buf.str() << std::endl;
}
}
}
}
// print validation metric
if ((iter % gbdt_config_->output_freq) == 0 || early_stopping_round_ > 0) {
if (need_output || early_stopping_round_ > 0) {
for (size_t i = 0; i < valid_metrics_.size(); ++i) {
for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score());
if ((iter % gbdt_config_->output_freq) == 0) {
auto name = valid_metrics_[i][j]->GetName();
for (size_t k = 0; k < name.size(); ++k) {
Log::Info("Iteration:%d, valid_%d %s : %f", iter, i + 1, name[k].c_str(), test_scores[k]);
auto name = valid_metrics_[i][j]->GetName();
for (size_t k = 0; k < name.size(); ++k) {
std::stringstream tmp_buf;
tmp_buf << "Iteration:" << iter
<< ", valid_" << i + 1 << " " << name[k]
<< " : " << test_scores[k];
if (need_output) {
Log::Info(tmp_buf.str().c_str());
}
if (early_stopping_round_ > 0) {
msg_buf << tmp_buf.str() << std::endl;
}
}
if (!ret && early_stopping_round_ > 0) {
if (ret.empty() && early_stopping_round_ > 0) {
auto cur_score = valid_metrics_[i][j]->factor_to_bigger_better() * test_scores.back();
if (cur_score > best_score_[i][j]) {
best_score_[i][j] = cur_score;
best_iter_[i][j] = iter;
meet_early_stopping_pairs.emplace_back(i, j);
} else {
if (iter - best_iter_[i][j] >= early_stopping_round_) { ret = true; }
if (iter - best_iter_[i][j] >= early_stopping_round_) { ret = best_msg_[i][j]; }
}
}
}
}
}
for (auto& pair : meet_early_stopping_pairs) {
best_msg_[pair.first][pair.second] = msg_buf.str();
}
return ret;
}
......@@ -419,14 +444,14 @@ void GBDT::Boosting() {
}
std::string GBDT::DumpModel() const {
std::stringstream ss;
std::stringstream str_buf;
ss << "{";
ss << "\"name\":\"" << Name() << "\"," << std::endl;
ss << "\"num_class\":" << num_class_ << "," << std::endl;
ss << "\"label_index\":" << label_idx_ << "," << std::endl;
ss << "\"max_feature_idx\":" << max_feature_idx_ << "," << std::endl;
ss << "\"sigmoid\":" << sigmoid_ << "," << std::endl;
str_buf << "{";
str_buf << "\"name\":\"" << Name() << "\"," << std::endl;
str_buf << "\"num_class\":" << num_class_ << "," << std::endl;
str_buf << "\"label_index\":" << label_idx_ << "," << std::endl;
str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << std::endl;
str_buf << "\"sigmoid\":" << sigmoid_ << "," << std::endl;
// output feature names
auto feature_names = std::ref(feature_names_);
......@@ -434,25 +459,25 @@ std::string GBDT::DumpModel() const {
feature_names = std::ref(train_data_->feature_names());
}
ss << "\"feature_names\":[\""
str_buf << "\"feature_names\":[\""
<< Common::Join(feature_names.get(), "\",\"") << "\"],"
<< std::endl;
ss << "\"tree_info\":[";
str_buf << "\"tree_info\":[";
for (int i = 0; i < static_cast<int>(models_.size()); ++i) {
if (i > 0) {
ss << ",";
str_buf << ",";
}
ss << "{";
ss << "\"tree_index\":" << i << ",";
ss << models_[i]->ToJSON();
ss << "}";
str_buf << "{";
str_buf << "\"tree_index\":" << i << ",";
str_buf << models_[i]->ToJSON();
str_buf << "}";
}
ss << "]" << std::endl;
str_buf << "]" << std::endl;
ss << "}" << std::endl;
str_buf << "}" << std::endl;
return ss.str();
return str_buf.str();
}
void GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
......
......@@ -225,8 +225,9 @@ protected:
/*!
* \brief Print metric result of current iteration
* \param iter Current iteration
* \return Output message of the best iteration if early stopping is met; an empty string otherwise
*/
bool OutputMetric(int iter);
std::string OutputMetric(int iter);
/*!
* \brief Calculate feature importances
* \param last_iter Last tree used to calculate
......@@ -252,9 +253,12 @@ protected:
std::vector<std::vector<const Metric*>> valid_metrics_;
/*! \brief Number of rounds for early stopping */
int early_stopping_round_;
/*! \brief Best score(s) for early stopping */
/*! \brief Best iteration(s) for early stopping */
std::vector<std::vector<int>> best_iter_;
/*! \brief Best score(s) for early stopping */
std::vector<std::vector<double>> best_score_;
/*! \brief Output message of best iteration */
std::vector<std::vector<std::string>> best_msg_;
/*! \brief Trained models(trees) */
std::vector<std::unique_ptr<Tree>> models_;
/*! \brief Max feature index of training data*/
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment