Commit 29900dd7 authored by Guolin Ke's avatar Guolin Ke
Browse files

Move the training loop into the boosting class.

parent 8a2b644d
...@@ -56,6 +56,8 @@ public: ...@@ -56,6 +56,8 @@ public:
virtual void AddValidDataset(const Dataset* valid_data, virtual void AddValidDataset(const Dataset* valid_data,
const std::vector<const Metric*>& valid_metrics) = 0; const std::vector<const Metric*>& valid_metrics) = 0;
virtual void Train(int snapshot_freq, const std::string& model_output_path) = 0;
/*! /*!
* \brief Training logic * \brief Training logic
* \param gradient nullptr for using default objective, otherwise use self-defined boosting * \param gradient nullptr for using default objective, otherwise use self-defined boosting
......
...@@ -224,24 +224,7 @@ void Application::InitTrain() { ...@@ -224,24 +224,7 @@ void Application::InitTrain() {
void Application::Train() { void Application::Train() {
Log::Info("Started training..."); Log::Info("Started training...");
int total_iter = config_.boosting_config.num_iterations; boosting_->Train(config_.io_config.snapshot_freq, config_.io_config.output_model);
bool is_finished = false;
bool need_eval = true;
auto start_time = std::chrono::steady_clock::now();
for (int iter = 0; iter < total_iter && !is_finished; ++iter) {
is_finished = boosting_->TrainOneIter(nullptr, nullptr, need_eval);
auto end_time = std::chrono::steady_clock::now();
// output used time per iteration
Log::Info("%f seconds elapsed, finished iteration %d", std::chrono::duration<double,
std::milli>(end_time - start_time) * 1e-3, iter + 1);
if (config_.io_config.snapshot_freq > 0
&& (iter+1) % config_.io_config.snapshot_freq == 0) {
std::string snapshot_out = config_.io_config.output_model + ".snapshot_iter_" + std::to_string(iter + 1);
boosting_->SaveModelToFile(-1, snapshot_out.c_str());
}
}
// save model to file
boosting_->SaveModelToFile(-1, config_.io_config.output_model.c_str());
// convert model to if-else statement code // convert model to if-else statement code
if (config_.convert_model_language == std::string("cpp")) { if (config_.convert_model_language == std::string("cpp")) {
boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str()); boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
......
...@@ -442,6 +442,25 @@ double ObtainAutomaticInitialScore(const ObjectiveFunction* objf, const float* l ...@@ -442,6 +442,25 @@ double ObtainAutomaticInitialScore(const ObjectiveFunction* objf, const float* l
} }
} }
/*!
* \brief Run the full training loop for up to gbdt_config_->num_iterations
*        boosting rounds, logging elapsed time per iteration, optionally
*        writing periodic snapshots, and saving the final model.
* \param snapshot_freq Write a model snapshot every snapshot_freq iterations;
*        a value <= 0 disables snapshotting
* \param model_output_path Path for the final model file; snapshots are
*        written to model_output_path + ".snapshot_iter_<k>"
*/
void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
  bool is_finished = false;
  bool need_eval = true;
  auto start_time = std::chrono::steady_clock::now();
  // TrainOneIter may signal early stopping by returning true
  for (int iter = 0; iter < gbdt_config_->num_iterations && !is_finished; ++iter) {
    is_finished = TrainOneIter(nullptr, nullptr, need_eval);
    auto end_time = std::chrono::steady_clock::now();
    // Reduce the duration to a plain double before handing it to the
    // printf-style variadic Log::Info: passing the duration object itself
    // through "..." with %f is undefined behavior (it only appeared to work
    // because of the ABI's handling of single-double structs).
    double elapsed_secs = std::chrono::duration<double>(end_time - start_time).count();
    Log::Info("%f seconds elapsed, finished iteration %d", elapsed_secs, iter + 1);
    if (snapshot_freq > 0
        && (iter + 1) % snapshot_freq == 0) {
      std::string snapshot_out = model_output_path + ".snapshot_iter_" + std::to_string(iter + 1);
      SaveModelToFile(-1, snapshot_out.c_str());
    }
  }
  // Always persist the final model, independent of snapshot settings.
  SaveModelToFile(-1, model_output_path.c_str());
}
bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) { bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) {
// boosting from average label; or customized "average" if implemented for the current objective // boosting from average label; or customized "average" if implemented for the current objective
if (models_.empty() if (models_.empty()
...@@ -461,6 +480,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is ...@@ -461,6 +480,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
models_.push_back(std::move(new_tree)); models_.push_back(std::move(new_tree));
boost_from_average_ = true; boost_from_average_ = true;
} }
// boosting first // boosting first
if (gradient == nullptr || hessian == nullptr) { if (gradient == nullptr || hessian == nullptr) {
#ifdef TIMETAG #ifdef TIMETAG
...@@ -481,6 +501,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is ...@@ -481,6 +501,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
#ifdef TIMETAG #ifdef TIMETAG
bagging_time += std::chrono::steady_clock::now() - start_time; bagging_time += std::chrono::steady_clock::now() - start_time;
#endif #endif
// need to use subset gradient and hessian
if (is_use_subset_ && bag_data_cnt_ < num_data_) { if (is_use_subset_ && bag_data_cnt_ < num_data_) {
#ifdef TIMETAG #ifdef TIMETAG
start_time = std::chrono::steady_clock::now(); start_time = std::chrono::steady_clock::now();
......
...@@ -74,6 +74,9 @@ public: ...@@ -74,6 +74,9 @@ public:
*/ */
void AddValidDataset(const Dataset* valid_data, void AddValidDataset(const Dataset* valid_data,
const std::vector<const Metric*>& valid_metrics) override; const std::vector<const Metric*>& valid_metrics) override;
void Train(int snapshot_freq, const std::string& model_output_path) override;
/*! /*!
* \brief Training logic * \brief Training logic
* \param gradient nullptr for using default objective, otherwise use self-defined boosting * \param gradient nullptr for using default objective, otherwise use self-defined boosting
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment