"include/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "b752170ba950890dd7e211fdecedef4d8deffb8b"
Commit 13329682 authored by Guolin Ke's avatar Guolin Ke
Browse files

support rollback iteration and reset config during training.

parent 422c0ef7
...@@ -35,6 +35,12 @@ public: ...@@ -35,6 +35,12 @@ public:
const ObjectiveFunction* object_function, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics) = 0; const std::vector<const Metric*>& training_metrics) = 0;
/*!
* \brief Reset Config for current boosting
* \param config Configs for boosting
*/
virtual void ResetConfig(const BoostingConfig* config) = 0;
/*! /*!
* \brief Add a validation data * \brief Add a validation data
* \param valid_data Validation data * \param valid_data Validation data
...@@ -52,6 +58,19 @@ public: ...@@ -52,6 +58,19 @@ public:
*/ */
virtual bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) = 0; virtual bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) = 0;
/*!
* \brief Rollback one iteration
*/
virtual void RollbackOneIter() = 0;
/*!
* \brief return current iteration
*/
virtual int GetCurrentIteration() const = 0;
/*!
* \brief Eval metrics and check is met early stopping or not
*/
virtual bool EvalAndCheckEarlyStopping() = 0; virtual bool EvalAndCheckEarlyStopping() = 0;
/*! /*!
* \brief Get evaluation result at data_idx data * \brief Get evaluation result at data_idx data
......
...@@ -239,6 +239,7 @@ DllExport int LGBM_BoosterCreateFromModelfile( ...@@ -239,6 +239,7 @@ DllExport int LGBM_BoosterCreateFromModelfile(
int64_t* out_num_total_model, int64_t* out_num_total_model,
BoosterHandle* out); BoosterHandle* out);
/*! /*!
* \brief free obj in handle * \brief free obj in handle
* \param handle handle to be freed * \param handle handle to be freed
...@@ -246,6 +247,13 @@ DllExport int LGBM_BoosterCreateFromModelfile( ...@@ -246,6 +247,13 @@ DllExport int LGBM_BoosterCreateFromModelfile(
*/ */
DllExport int LGBM_BoosterFree(BoosterHandle handle); DllExport int LGBM_BoosterFree(BoosterHandle handle);
/*!
* \brief Reset config for current booster
* \param parameters format: 'key1=value1 key2=value2'
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters);
/*! /*!
* \brief Get number of class * \brief Get number of class
* \return number of class * \return number of class
...@@ -274,6 +282,19 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, ...@@ -274,6 +282,19 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
const float* hess, const float* hess,
int* is_finished); int* is_finished);
/*!
* \brief Rollback one iteration
* \param handle handle
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle);
/*!
* \brief Get iteration of current boosting rounds
* \return iteration of boosting rounds
*/
DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int64_t* out_iteration);
/*! /*!
* \brief Get number of eval * \brief Get number of eval
* \return total number of eval result * \return total number of eval result
......
...@@ -36,6 +36,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O ...@@ -36,6 +36,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
train_data_ = train_data; train_data_ = train_data;
num_class_ = config->num_class; num_class_ = config->num_class;
// create tree learner // create tree learner
tree_learner_.clear();
for (int i = 0; i < num_class_; ++i) { for (int i = 0; i < num_class_; ++i) {
auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config)); auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config));
new_tree_learner->Init(train_data_); new_tree_learner->Init(train_data_);
...@@ -82,6 +83,32 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O ...@@ -82,6 +83,32 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
} }
void GBDT::ResetConfig(const BoostingConfig* config) {
gbdt_config_ = config;
early_stopping_round_ = gbdt_config_->early_stopping_round;
shrinkage_rate_ = gbdt_config_->learning_rate;
// create tree learner
tree_learner_.clear();
for (int i = 0; i < num_class_; ++i) {
auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config));
new_tree_learner->Init(train_data_);
// init tree learner
tree_learner_.push_back(std::move(new_tree_learner));
}
tree_learner_.shrink_to_fit();
// if need bagging, create buffer
if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) {
out_of_bag_data_indices_ = std::vector<data_size_t>(num_data_);
bag_data_indices_ = std::vector<data_size_t>(num_data_);
} else {
out_of_bag_data_cnt_ = 0;
out_of_bag_data_indices_.clear();
bag_data_cnt_ = num_data_;
bag_data_indices_.clear();
}
// initialize random generator
random_ = Random(gbdt_config_->bagging_seed);
}
void GBDT::AddDataset(const Dataset* valid_data, void GBDT::AddDataset(const Dataset* valid_data,
const std::vector<const Metric*>& valid_metrics) { const std::vector<const Metric*>& valid_metrics) {
if (iter_ > 0) { if (iter_ > 0) {
...@@ -204,6 +231,25 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is ...@@ -204,6 +231,25 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
} }
void GBDT::RollbackOneIter() {
if (iter_ == 0) { return; }
int cur_iter = iter_ - 1;
// reset score
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
auto curr_tree = cur_iter * num_class_ + curr_class;
models_[curr_tree]->Shrinkage(-1.0);
train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
for (auto& score_updater : valid_score_updater_) {
score_updater->AddScore(models_[curr_tree].get(), curr_class);
}
}
// remove model
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
models_.pop_back();
}
--iter_;
}
bool GBDT::EvalAndCheckEarlyStopping() { bool GBDT::EvalAndCheckEarlyStopping() {
bool is_met_early_stopping = false; bool is_met_early_stopping = false;
// print message for metric // print message for metric
......
...@@ -35,6 +35,13 @@ public: ...@@ -35,6 +35,13 @@ public:
void Init(const BoostingConfig* gbdt_config, const Dataset* train_data, const ObjectiveFunction* object_function, void Init(const BoostingConfig* gbdt_config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics) const std::vector<const Metric*>& training_metrics)
override; override;
/*!
* \brief Reset Config for current boosting
* \param config Configs for boosting
*/
void ResetConfig(const BoostingConfig* config) override;
/*! /*!
* \brief Adding a validation dataset * \brief Adding a validation dataset
* \param valid_data Validation dataset * \param valid_data Validation dataset
...@@ -51,6 +58,13 @@ public: ...@@ -51,6 +58,13 @@ public:
*/ */
virtual bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override; virtual bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override;
/*!
* \brief Rollback one iteration
*/
void RollbackOneIter() override;
int GetCurrentIteration() const override { return iter_; }
bool EvalAndCheckEarlyStopping() override; bool EvalAndCheckEarlyStopping() override;
/*! /*!
......
...@@ -150,6 +150,17 @@ public: ...@@ -150,6 +150,17 @@ public:
return idx; return idx;
} }
void ResetBoostingConfig(const char* parameters) {
OverallConfig new_config;
new_config.LoadFromString(parameters);
config_.boosting_config = new_config.boosting_config;
boosting_->ResetConfig(&config_.boosting_config);
}
void RollbackOneIter() {
boosting_->RollbackOneIter();
}
const Boosting* GetBoosting() const { return boosting_.get(); } const Boosting* GetBoosting() const { return boosting_.get(); }
private: private:
...@@ -471,6 +482,13 @@ DllExport int LGBM_BoosterFree(BoosterHandle handle) { ...@@ -471,6 +482,13 @@ DllExport int LGBM_BoosterFree(BoosterHandle handle) {
API_END(); API_END();
} }
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->ResetBoostingConfig(parameters);
API_END();
}
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len) { DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
...@@ -503,6 +521,19 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, ...@@ -503,6 +521,19 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
API_END(); API_END();
} }
DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->RollbackOneIter();
API_END();
}
DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int64_t* out_iteration) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_iteration = ref_booster->GetBoosting()->GetCurrentIteration();
API_END();
}
/*! /*!
* \brief Get number of eval * \brief Get number of eval
* \return total number of eval result * \return total number of eval result
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment