Commit 2f984f3a authored by Allard van Mossel's avatar Allard van Mossel Committed by Guolin Ke
Browse files

Fixed bug when num_iterations > early_stopping_round (#50)

parent b24190a6
...@@ -217,9 +217,9 @@ void GBDT::Train() { ...@@ -217,9 +217,9 @@ void GBDT::Train() {
} }
} }
// close file // close file
if (early_stopping_round_ > 0) { int remaining_models = gbdt_config_->num_iterations - early_stopping_round_;
// save remaining models if (early_stopping_round_ > 0 && remaining_models > 0) {
for (int iter = gbdt_config_->num_iterations - early_stopping_round_; iter < static_cast<int>(models_.size()); ++iter){ for (int iter = remaining_models; iter < static_cast<int>(models_.size()); ++iter){
fprintf(output_model_file, "Tree=%d\n", iter); fprintf(output_model_file, "Tree=%d\n", iter);
fprintf(output_model_file, "%s\n", models_.at(iter)->ToString().c_str()); fprintf(output_model_file, "%s\n", models_.at(iter)->ToString().c_str());
} }
...@@ -254,7 +254,7 @@ bool GBDT::OutputMetric(int iter) { ...@@ -254,7 +254,7 @@ bool GBDT::OutputMetric(int iter) {
score_t test_score_ = valid_metrics_[i][j]->PrintAndGetLoss(iter, valid_score_updater_[i]->score()); score_t test_score_ = valid_metrics_[i][j]->PrintAndGetLoss(iter, valid_score_updater_[i]->score());
if (!ret && early_stopping_round_ > 0){ if (!ret && early_stopping_round_ > 0){
bool the_bigger_the_better_ = valid_metrics_[i][j]->the_bigger_the_better; bool the_bigger_the_better_ = valid_metrics_[i][j]->the_bigger_the_better;
if (best_score_[i][j] < 0 if (best_score_[i][j] < 0
|| (!the_bigger_the_better_ && test_score_ < best_score_[i][j]) || (!the_bigger_the_better_ && test_score_ < best_score_[i][j])
|| ( the_bigger_the_better_ && test_score_ > best_score_[i][j])){ || ( the_bigger_the_better_ && test_score_ > best_score_[i][j])){
best_score_[i][j] = test_score_; best_score_[i][j] = test_score_;
...@@ -390,7 +390,7 @@ void GBDT::FeatureImportance(const int last_iter) { ...@@ -390,7 +390,7 @@ void GBDT::FeatureImportance(const int last_iter) {
std::sort(pairs.begin(), pairs.end(), std::sort(pairs.begin(), pairs.end(),
[](const std::pair<size_t, std::string>& lhs, [](const std::pair<size_t, std::string>& lhs,
const std::pair<size_t, std::string>& rhs) { const std::pair<size_t, std::string>& rhs) {
return lhs.first > rhs.first; return lhs.first > rhs.first;
}); });
// write to model file // write to model file
fprintf(output_model_file, "\nfeature importances:\n"); fprintf(output_model_file, "\nfeature importances:\n");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment