Commit edd618bf authored by Guolin Ke's avatar Guolin Ke
Browse files

add model output snapshot.

parent df52c128
...@@ -90,6 +90,7 @@ public: ...@@ -90,6 +90,7 @@ public:
int data_random_seed = 1; int data_random_seed = 1;
std::string data_filename = ""; std::string data_filename = "";
std::vector<std::string> valid_data_filenames; std::vector<std::string> valid_data_filenames;
int snapshot_freq = 100;
std::string output_model = "LightGBM_model.txt"; std::string output_model = "LightGBM_model.txt";
std::string output_result = "LightGBM_predict_result.txt"; std::string output_result = "LightGBM_predict_result.txt";
std::string input_model = ""; std::string input_model = "";
......
...@@ -231,6 +231,11 @@ void Application::Train() { ...@@ -231,6 +231,11 @@ void Application::Train() {
// output used time per iteration // output used time per iteration
Log::Info("%f seconds elapsed, finished iteration %d", std::chrono::duration<double, Log::Info("%f seconds elapsed, finished iteration %d", std::chrono::duration<double,
std::milli>(end_time - start_time) * 1e-3, iter + 1); std::milli>(end_time - start_time) * 1e-3, iter + 1);
if (config_.io_config.snapshot_freq > 0
&& (iter+1) % config_.io_config.snapshot_freq == 0) {
std::string snapshot_out = config_.io_config.output_model + ".snapshot_iter_" + std::to_string(iter + 1);
boosting_->SaveModelToFile(-1, snapshot_out.c_str());
}
} }
// save model to file // save model to file
boosting_->SaveModelToFile(-1, config_.io_config.output_model.c_str()); boosting_->SaveModelToFile(-1, config_.io_config.output_model.c_str());
......
...@@ -207,6 +207,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { ...@@ -207,6 +207,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetBool(params, "enable_load_from_binary_file", &enable_load_from_binary_file); GetBool(params, "enable_load_from_binary_file", &enable_load_from_binary_file);
GetBool(params, "is_predict_raw_score", &is_predict_raw_score); GetBool(params, "is_predict_raw_score", &is_predict_raw_score);
GetBool(params, "is_predict_leaf_index", &is_predict_leaf_index); GetBool(params, "is_predict_leaf_index", &is_predict_leaf_index);
GetInt(params, "snapshot_freq", &snapshot_freq);
GetString(params, "output_model", &output_model); GetString(params, "output_model", &output_model);
GetString(params, "input_model", &input_model); GetString(params, "input_model", &input_model);
GetString(params, "output_result", &output_result); GetString(params, "output_result", &output_result);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment