Commit 563e1464 authored by wxchan, committed by Guolin Ke
Browse files

update early stopping feature (#25)

update early stopping feature (#25)
parent 2d0e8fc9
......@@ -310,7 +310,9 @@ struct ParameterAlias {
{ "two_round", "use_two_round_loading" },
{ "mlist", "machine_list_file" },
{ "is_save_binary", "is_save_binary_file" },
{ "save_binary", "is_save_binary_file" }
{ "save_binary", "is_save_binary_file" },
{ "early_stopping_rounds", "early_stopping_round"},
{ "early_stopping", "early_stopping_round"}
});
std::unordered_map<std::string, std::string> tmp_map;
for (const auto& pair : *params) {
......
......@@ -13,7 +13,6 @@
#include <string>
#include <vector>
namespace LightGBM {
GBDT::GBDT(const BoostingConfig* config)
......@@ -185,19 +184,44 @@ void GBDT::Train() {
UpdateScore(new_tree);
UpdateScoreOutOfBag(new_tree);
// print message for metric
if (OutputMetric(iter + 1)) return;
bool is_early_stopping = OutputMetric(iter + 1);
// add model
models_.push_back(new_tree);
// save model to file per iteration
fprintf(output_model_file, "Tree=%d\n", iter);
fprintf(output_model_file, "%s\n", new_tree->ToString().c_str());
fflush(output_model_file);
if (early_stopping_round_ > 0){
// if use early stopping, save previous model at (iter - early_stopping_round_) iteration
if (iter >= early_stopping_round_){
fprintf(output_model_file, "Tree=%d\n", iter - early_stopping_round_);
Tree * printing_tree = models_.at(iter - early_stopping_round_);
fprintf(output_model_file, "%s\n", printing_tree->ToString().c_str());
fflush(output_model_file);
}
}
else{
fprintf(output_model_file, "Tree=%d\n", iter);
fprintf(output_model_file, "%s\n", new_tree->ToString().c_str());
fflush(output_model_file);
}
auto end_time = std::chrono::high_resolution_clock::now();
// output used time per iteration
Log::Stdout("%f seconds elapsed, finished %d iteration", std::chrono::duration<double,
std::milli>(end_time - start_time) * 1e-3, iter + 1);
if (is_early_stopping) {
// close file with an early-stopping message
Log::Stdout("early stopping at iteration %d, the best iteration round is %d", iter + 1, iter + 1 - early_stopping_round_);
fclose(output_model_file);
return;
}
}
// close file
if (early_stopping_round_ > 0) {
// save remaining models
for (int iter = gbdt_config_->num_iterations - early_stopping_round_; iter < models_.size(); ++iter){
fprintf(output_model_file, "Tree=%d\n", iter);
fprintf(output_model_file, "%s\n", models_.at(iter)->ToString().c_str());
}
fflush(output_model_file);
}
fclose(output_model_file);
}
......
......@@ -52,7 +52,7 @@ public:
void Print(int iter, const score_t* score, score_t& loss) const override {
score_t sum_loss = 0.0f;
if (early_stopping_round_ > 0 || output_freq_ > 0 && iter % output_freq_ == 0) {
if (early_stopping_round_ > 0 || (output_freq_ > 0 && iter % output_freq_ == 0)) {
if (weights_ == nullptr) {
#pragma omp parallel for schedule(static) reduction(+:sum_loss)
for (data_size_t i = 0; i < num_data_; ++i) {
......@@ -171,7 +171,7 @@ public:
}
void Print(int iter, const score_t* score, score_t& loss) const override {
if (early_stopping_round_ > 0 || output_freq_ > 0 && iter % output_freq_ == 0) {
if (early_stopping_round_ > 0 || (output_freq_ > 0 && iter % output_freq_ == 0)) {
// get indices sorted by score, descent order
std::vector<data_size_t> sorted_idx;
for (data_size_t i = 0; i < num_data_; ++i) {
......
......@@ -76,7 +76,7 @@ public:
}
void Print(int iter, const score_t* score, score_t& loss) const override {
if (early_stopping_round_ > 0 || output_freq_ > 0 && iter % output_freq_ == 0) {
if (early_stopping_round_ > 0 || (output_freq_ > 0 && iter % output_freq_ == 0)) {
// some buffers for multi-threading sum up
std::vector<std::vector<double>> result_buffer_;
for (int i = 0; i < num_threads_; ++i) {
......
......@@ -43,7 +43,7 @@ public:
}
void Print(int iter, const score_t* score, score_t& loss) const override {
if (early_stopping_round_ > 0 || output_freq_ > 0 && iter % output_freq_ == 0) {
if (early_stopping_round_ > 0 || (output_freq_ > 0 && iter % output_freq_ == 0)) {
score_t sum_loss = 0.0;
if (weights_ == nullptr) {
#pragma omp parallel for schedule(static) reduction(+:sum_loss)
......
Markdown is supported
Attach a file by drag &amp; drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment