Unverified Commit 92f2a570 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

support refit model by new data (#1124)

* add code for refit tree

* add implementation.

* update documents.

* clean code

* fix a typo
parent aa78a6b9
...@@ -39,7 +39,9 @@ Core Parameters ...@@ -39,7 +39,9 @@ Core Parameters
- path of config file - path of config file
- ``task``, default=\ ``train``, type=enum, options=\ ``train``, ``predict``, ``convert_model`` - **Note**: Only can be used in CLI version.
- ``task``, default=\ ``train``, type=enum, options=\ ``train``, ``predict``, ``convert_model``, ``refit``
- ``train``, alias=\ ``training``, for training - ``train``, alias=\ ``training``, for training
...@@ -47,6 +49,10 @@ Core Parameters ...@@ -47,6 +49,10 @@ Core Parameters
- ``convert_model``, for converting model file into if-else format, see more information in `Convert model parameters <#convert-model-parameters>`__ - ``convert_model``, for converting model file into if-else format, see more information in `Convert model parameters <#convert-model-parameters>`__
- ``refit``, alias=\ ``refit_tree``, refit existing models with new data.
- **Note**: Only can be used in CLI version.
- ``application``, default=\ ``regression``, type=enum, - ``application``, default=\ ``regression``, type=enum,
options=\ ``regression``, ``regression_l1``, ``huber``, ``fair``, ``poisson``, ``quantile``, ``quantile_l2``, options=\ ``regression``, ``regression_l1``, ``huber``, ``fair``, ``poisson``, ``quantile``, ``quantile_l2``,
``binary``, ``multiclass``, ``multiclassova``, ``xentropy``, ``xentlambda``, ``lambdarank``, ``binary``, ``multiclass``, ``multiclassova``, ``xentropy``, ``xentlambda``, ``lambdarank``,
......
...@@ -73,7 +73,7 @@ private: ...@@ -73,7 +73,7 @@ private:
inline void Application::Run() { inline void Application::Run() {
if (config_.task_type == TaskType::kPredict) { if (config_.task_type == TaskType::kPredict || config_.task_type == TaskType::KRefitTree) {
InitPredict(); InitPredict();
Predict(); Predict();
} else if (config_.task_type == TaskType::kConvertModel) { } else if (config_.task_type == TaskType::kConvertModel) {
......
...@@ -49,6 +49,8 @@ public: ...@@ -49,6 +49,8 @@ public:
virtual void ResetConfig(const BoostingConfig* config) = 0; virtual void ResetConfig(const BoostingConfig* config) = 0;
/*! /*!
* \brief Add a validation data * \brief Add a validation data
* \param valid_data Validation data * \param valid_data Validation data
...@@ -59,6 +61,11 @@ public: ...@@ -59,6 +61,11 @@ public:
virtual void Train(int snapshot_freq, const std::string& model_output_path) = 0; virtual void Train(int snapshot_freq, const std::string& model_output_path) = 0;
/*!
* \brief Update the tree output by new training data
*/
virtual void RefitTree(const std::vector<std::vector<int>>& tree_leaf_prediction) = 0;
/*! /*!
* \brief Training logic * \brief Training logic
* \param gradients nullptr for using default objective, otherwise use self-defined boosting * \param gradients nullptr for using default objective, otherwise use self-defined boosting
......
...@@ -87,7 +87,7 @@ public: ...@@ -87,7 +87,7 @@ public:
/*! \brief Types of tasks */ /*! \brief Types of tasks */
enum TaskType { enum TaskType {
kTrain, kPredict, kConvertModel kTrain, kPredict, kConvertModel, KRefitTree
}; };
/*! \brief Config for input and output files */ /*! \brief Config for input and output files */
......
...@@ -146,6 +146,10 @@ public: ...@@ -146,6 +146,10 @@ public:
shrinkage_ *= rate; shrinkage_ *= rate;
} }
/*! \brief Get the accumulated shrinkage (learning-rate) factor applied to this tree's leaf outputs */
inline double shrinkage() const {
  return shrinkage_;
}
inline void AddBias(double val) { inline void AddBias(double val) {
#pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048) #pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048)
for (int i = 0; i < num_leaves_; ++i) { for (int i = 0; i < num_leaves_; ++i) {
......
...@@ -50,6 +50,9 @@ public: ...@@ -50,6 +50,9 @@ public:
*/ */
virtual Tree* FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t* hessians) const = 0; virtual Tree* FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t* hessians) const = 0;
virtual Tree* FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
const score_t* gradients, const score_t* hessians) = 0;
/*! /*!
* \brief Set bagging data * \brief Set bagging data
* \param used_indices Used data indices * \param used_indices Used data indices
......
...@@ -87,7 +87,7 @@ void Application::LoadData() { ...@@ -87,7 +87,7 @@ void Application::LoadData() {
PredictFunction predict_fun = nullptr; PredictFunction predict_fun = nullptr;
PredictionEarlyStopInstance pred_early_stop = CreatePredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig()); PredictionEarlyStopInstance pred_early_stop = CreatePredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig());
// need to continue training // need to continue training
if (boosting_->NumberOfTotalModel() > 0) { if (boosting_->NumberOfTotalModel() > 0 || config_.task_type != TaskType::KRefitTree) {
predictor.reset(new Predictor(boosting_.get(), -1, true, false, false, false, -1, -1)); predictor.reset(new Predictor(boosting_.get(), -1, true, false, false, false, -1, -1));
predict_fun = predictor->GetPredictFunction(); predict_fun = predictor->GetPredictFunction();
} }
...@@ -212,6 +212,34 @@ void Application::Train() { ...@@ -212,6 +212,34 @@ void Application::Train() {
} }
void Application::Predict() { void Application::Predict() {
if (config_.task_type == TaskType::KRefitTree) {
// create predictor
Predictor predictor(boosting_.get(), -1, false, true, false, false, 1, 1);
predictor.Predict(config_.io_config.data_filename.c_str(), config_.io_config.output_result.c_str(), config_.io_config.has_header);
TextReader<int> result_reader(config_.io_config.output_result.c_str(), false);
result_reader.ReadAllLines();
std::vector<std::vector<int>> pred_leaf(result_reader.Lines().size());
#pragma omp parallel for schedule(static)
for (int i = 0; i < static_cast<int>(result_reader.Lines().size()); ++i) {
pred_leaf[i] = Common::StringToArray<int>(result_reader.Lines()[i], '\t');
// Free memory
result_reader.Lines()[i].clear();
}
DatasetLoader dataset_loader(config_.io_config, nullptr,
config_.boosting_config.num_class, config_.io_config.data_filename.c_str());
train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(), config_.io_config.initscore_filename.c_str(),
0, 1));
train_metric_.clear();
objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config));
objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
boosting_->Init(&config_.boosting_config, train_data_.get(), objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
boosting_->RefitTree(pred_leaf);
boosting_->SaveModelToFile(-1, config_.io_config.output_model.c_str());
Log::Info("Finished RefitTree");
} else {
// create predictor // create predictor
Predictor predictor(boosting_.get(), config_.io_config.num_iteration_predict, config_.io_config.is_predict_raw_score, Predictor predictor(boosting_.get(), config_.io_config.num_iteration_predict, config_.io_config.is_predict_raw_score,
config_.io_config.is_predict_leaf_index, config_.io_config.is_predict_contrib, config_.io_config.is_predict_leaf_index, config_.io_config.is_predict_contrib,
...@@ -220,6 +248,7 @@ void Application::Predict() { ...@@ -220,6 +248,7 @@ void Application::Predict() {
predictor.Predict(config_.io_config.data_filename.c_str(), predictor.Predict(config_.io_config.data_filename.c_str(),
config_.io_config.output_result.c_str(), config_.io_config.has_header); config_.io_config.output_result.c_str(), config_.io_config.has_header);
Log::Info("Finished prediction"); Log::Info("Finished prediction");
}
} }
void Application::InitPredict() { void Application::InitPredict() {
......
...@@ -42,6 +42,8 @@ public: ...@@ -42,6 +42,8 @@ public:
early_stop_ = CreatePredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig()); early_stop_ = CreatePredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig());
if (early_stop && !boosting->NeedAccuratePrediction()) { if (early_stop && !boosting->NeedAccuratePrediction()) {
PredictionEarlyStopConfig pred_early_stop_config; PredictionEarlyStopConfig pred_early_stop_config;
CHECK(early_stop_freq > 0);
CHECK(early_stop_margin >= 0);
pred_early_stop_config.margin_threshold = early_stop_margin; pred_early_stop_config.margin_threshold = early_stop_margin;
pred_early_stop_config.round_period = early_stop_freq; pred_early_stop_config.round_period = early_stop_freq;
if (boosting->NumberOfClasses() == 1) { if (boosting->NumberOfClasses() == 1) {
......
...@@ -353,6 +353,30 @@ void GBDT::Train(int snapshot_freq, const std::string& model_output_path) { ...@@ -353,6 +353,30 @@ void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
} }
} }
void GBDT::RefitTree(const std::vector<std::vector<int>>& tree_leaf_prediction) {
CHECK(tree_leaf_prediction.size() > 0);
CHECK(static_cast<size_t>(num_data_) == tree_leaf_prediction.size());
CHECK(static_cast<size_t>(models_.size()) == tree_leaf_prediction[0].size());
int num_iterations = static_cast<int>(models_.size() / num_tree_per_iteration_);
std::vector<int> leaf_pred(num_data_);
for (int iter = 0; iter < num_iterations; ++iter) {
Boosting();
for (int tree_id = 0; tree_id < num_tree_per_iteration_; ++tree_id) {
int model_index = iter * num_tree_per_iteration_ + tree_id;
#pragma omp parallel for schedule(static)
for (int i = 0; i < num_data_; ++i) {
leaf_pred[i] = tree_leaf_prediction[i][model_index];
}
size_t bias = static_cast<size_t>(tree_id) * num_data_;
auto grad = gradients_.data() + bias;
auto hess = hessians_.data() + bias;
auto new_tree = tree_learner_->FitByExistingTree(models_[model_index].get(), leaf_pred, grad, hess);
train_score_updater_->AddScore(tree_learner_.get(), new_tree, tree_id);
models_[model_index].reset(new_tree);
}
}
}
double GBDT::BoostFromAverage() { double GBDT::BoostFromAverage() {
// boosting from average label; or customized "average" if implemented for the current objective // boosting from average label; or customized "average" if implemented for the current objective
if (models_.empty() if (models_.empty()
......
...@@ -96,6 +96,8 @@ public: ...@@ -96,6 +96,8 @@ public:
*/ */
void Train(int snapshot_freq, const std::string& model_output_path) override; void Train(int snapshot_freq, const std::string& model_output_path) override;
void RefitTree(const std::vector<std::vector<int>>& tree_leaf_prediction) override;
/*! /*!
* \brief Training logic * \brief Training logic
* \param gradients nullptr for using default objective, otherwise use self-defined boosting * \param gradients nullptr for using default objective, otherwise use self-defined boosting
......
...@@ -103,6 +103,8 @@ void GetTaskType(const std::unordered_map<std::string, std::string>& params, Tas ...@@ -103,6 +103,8 @@ void GetTaskType(const std::unordered_map<std::string, std::string>& params, Tas
*task_type = TaskType::kPredict; *task_type = TaskType::kPredict;
} else if (value == std::string("convert_model")) { } else if (value == std::string("convert_model")) {
*task_type = TaskType::kConvertModel; *task_type = TaskType::kConvertModel;
} else if (value == std::string("refit") || value == std::string("refit_tree")) {
*task_type = TaskType::KRefitTree;
} else { } else {
Log::Fatal("Unknown task type %s", value.c_str()); Log::Fatal("Unknown task type %s", value.c_str());
} }
......
...@@ -71,6 +71,21 @@ public: ...@@ -71,6 +71,21 @@ public:
} }
} }
void ResetByLeafPred(const std::vector<int>& leaf_pred, int num_leaves) {
ResetLeaves(num_leaves);
std::vector<std::vector<data_size_t>> indices_per_leaf(num_leaves_);
for (data_size_t i = 0; i < static_cast<data_size_t>(leaf_pred.size()); ++i) {
indices_per_leaf[leaf_pred[i]].push_back(i);
}
data_size_t offset = 0;
for (int i = 0; i < num_leaves_; ++i) {
leaf_begin_[i] = offset;
leaf_count_[i] = static_cast<data_size_t>(indices_per_leaf[i].size());
std::copy(indices_per_leaf[i].begin(), indices_per_leaf[i].end(), indices_.begin() + leaf_begin_[i]);
offset += leaf_count_[i];
}
}
/*! /*!
* \brief Get the data indices of one leaf * \brief Get the data indices of one leaf
* \param leaf index of leaf * \param leaf index of leaf
......
...@@ -215,23 +215,26 @@ Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* ...@@ -215,23 +215,26 @@ Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t*
data_size_t cnt_leaf_data = 0; data_size_t cnt_leaf_data = 0;
auto tmp_idx = data_partition_->GetIndexOnLeaf(i, &cnt_leaf_data); auto tmp_idx = data_partition_->GetIndexOnLeaf(i, &cnt_leaf_data);
double sum_grad = 0.0f; double sum_grad = 0.0f;
double sum_hess = 0.0f; double sum_hess = kEpsilon;
for (data_size_t j = 0; j < cnt_leaf_data; ++j) { for (data_size_t j = 0; j < cnt_leaf_data; ++j) {
auto idx = tmp_idx[j]; auto idx = tmp_idx[j];
sum_grad += gradients[idx]; sum_grad += gradients[idx];
sum_hess += hessians[idx]; sum_hess += hessians[idx];
} }
// avoid zero hessians.
if (sum_hess <= 0) sum_hess = kEpsilon;
double output = FeatureHistogram::CalculateSplittedLeafOutput(sum_grad, sum_hess, double output = FeatureHistogram::CalculateSplittedLeafOutput(sum_grad, sum_hess,
tree_config_->lambda_l1, tree_config_->lambda_l2); tree_config_->lambda_l1, tree_config_->lambda_l2);
tree->SetLeafOutput(i, output); tree->SetLeafOutput(i, output* tree->shrinkage());
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
OMP_THROW_EX(); OMP_THROW_EX();
return tree.release(); return tree.release();
} }
// Refit an existing tree using an externally supplied per-row leaf assignment:
// re-partition the training rows to the leaves recorded in leaf_pred, then
// delegate to the gradient/hessian-based overload to recompute leaf outputs.
Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred, const score_t* gradients, const score_t *hessians) {
  data_partition_->ResetByLeafPred(leaf_pred, old_tree->num_leaves());
  return FitByExistingTree(old_tree, gradients, hessians);
}
void SerialTreeLearner::BeforeTrain() { void SerialTreeLearner::BeforeTrain() {
// reset histogram pool // reset histogram pool
......
...@@ -45,6 +45,9 @@ public: ...@@ -45,6 +45,9 @@ public:
Tree* FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t* hessians) const override; Tree* FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t* hessians) const override;
Tree* FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
const score_t* gradients, const score_t* hessians) override;
void SetBaggingData(const data_size_t* used_indices, data_size_t num_data) override { void SetBaggingData(const data_size_t* used_indices, data_size_t num_data) override {
data_partition_->SetUsedDataIndices(used_indices, num_data); data_partition_->SetUsedDataIndices(used_indices, num_data);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment