Unverified commit 1886bf51, authored by Christian Bourjau and committed by GitHub
Browse files

[c++] Avoid copy on Refit (#6478)

parent cd4459a1
...@@ -74,7 +74,7 @@ class LIGHTGBM_EXPORT Boosting { ...@@ -74,7 +74,7 @@ class LIGHTGBM_EXPORT Boosting {
/*! /*!
* \brief Update the tree output by new training data * \brief Update the tree output by new training data
*/ */
virtual void RefitTree(const std::vector<std::vector<int>>& tree_leaf_prediction) = 0; virtual void RefitTree(const int* tree_leaf_prediction, const size_t nrow, const size_t ncol) = 0;
/*! /*!
* \brief Training logic * \brief Training logic
......
...@@ -226,12 +226,24 @@ void Application::Predict() { ...@@ -226,12 +226,24 @@ void Application::Predict() {
config_.precise_float_parser); config_.precise_float_parser);
TextReader<int> result_reader(config_.output_result.c_str(), false); TextReader<int> result_reader(config_.output_result.c_str(), false);
result_reader.ReadAllLines(); result_reader.ReadAllLines();
std::vector<std::vector<int>> pred_leaf(result_reader.Lines().size());
size_t nrow = result_reader.Lines().size();
size_t ncol = 0;
if (nrow > 0) {
ncol = Common::StringToArray<int>(result_reader.Lines()[0], '\t').size();
}
std::vector<int> pred_leaf;
pred_leaf.resize(nrow * ncol);
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int i = 0; i < static_cast<int>(result_reader.Lines().size()); ++i) { for (int irow = 0; irow < static_cast<int>(nrow); ++irow) {
pred_leaf[i] = Common::StringToArray<int>(result_reader.Lines()[i], '\t'); auto line_vec = Common::StringToArray<int>(result_reader.Lines()[irow], '\t');
CHECK_EQ(line_vec.size(), ncol);
for (int i_row_item = 0; i_row_item < static_cast<int>(ncol); ++i_row_item) {
pred_leaf[irow * ncol + i_row_item] = line_vec[i_row_item];
}
// Free memory // Free memory
result_reader.Lines()[i].clear(); result_reader.Lines()[irow].clear();
} }
DatasetLoader dataset_loader(config_, nullptr, DatasetLoader dataset_loader(config_, nullptr,
config_.num_class, config_.data.c_str()); config_.num_class, config_.data.c_str());
...@@ -242,7 +254,8 @@ void Application::Predict() { ...@@ -242,7 +254,8 @@ void Application::Predict() {
objective_fun_->Init(train_data_->metadata(), train_data_->num_data()); objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
boosting_->Init(&config_, train_data_.get(), objective_fun_.get(), boosting_->Init(&config_, train_data_.get(), objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_)); Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
boosting_->RefitTree(pred_leaf);
boosting_->RefitTree(pred_leaf.data(), nrow, ncol);
boosting_->SaveModelToFile(0, -1, config_.saved_feature_importance_type, boosting_->SaveModelToFile(0, -1, config_.saved_feature_importance_type,
config_.output_model.c_str()); config_.output_model.c_str());
Log::Info("Finished RefitTree"); Log::Info("Finished RefitTree");
......
...@@ -249,32 +249,34 @@ void GBDT::Train(int snapshot_freq, const std::string& model_output_path) { ...@@ -249,32 +249,34 @@ void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
} }
} }
void GBDT::RefitTree(const std::vector<std::vector<int>>& tree_leaf_prediction) { void GBDT::RefitTree(const int* tree_leaf_prediction, const size_t nrow, const size_t ncol) {
CHECK_GT(tree_leaf_prediction.size(), 0); CHECK_GT(nrow * ncol, 0);
CHECK_EQ(static_cast<size_t>(num_data_), tree_leaf_prediction.size()); CHECK_EQ(static_cast<size_t>(num_data_), nrow);
CHECK_EQ(static_cast<size_t>(models_.size()), tree_leaf_prediction[0].size()); CHECK_EQ(models_.size(), ncol);
int num_iterations = static_cast<int>(models_.size() / num_tree_per_iteration_); int num_iterations = static_cast<int>(models_.size() / num_tree_per_iteration_);
std::vector<int> leaf_pred(num_data_); std::vector<int> leaf_pred(num_data_);
if (linear_tree_) { if (linear_tree_) {
std::vector<int> max_leaves_by_thread = std::vector<int>(OMP_NUM_THREADS(), 0); std::vector<int> max_leaves_by_thread = std::vector<int>(OMP_NUM_THREADS(), 0);
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int i = 0; i < static_cast<int>(tree_leaf_prediction.size()); ++i) { for (int i = 0; i < static_cast<int>(nrow); ++i) {
int tid = omp_get_thread_num(); int tid = omp_get_thread_num();
for (size_t j = 0; j < tree_leaf_prediction[i].size(); ++j) { for (size_t j = 0; j < ncol; ++j) {
max_leaves_by_thread[tid] = std::max(max_leaves_by_thread[tid], tree_leaf_prediction[i][j]); max_leaves_by_thread[tid] = std::max(max_leaves_by_thread[tid], tree_leaf_prediction[i * ncol + j]);
} }
} }
int max_leaves = *std::max_element(max_leaves_by_thread.begin(), max_leaves_by_thread.end()); int max_leaves = *std::max_element(max_leaves_by_thread.begin(), max_leaves_by_thread.end());
max_leaves += 1; max_leaves += 1;
tree_learner_->InitLinear(train_data_, max_leaves); tree_learner_->InitLinear(train_data_, max_leaves);
} }
for (int iter = 0; iter < num_iterations; ++iter) { for (int iter = 0; iter < num_iterations; ++iter) {
Boosting(); Boosting();
for (int tree_id = 0; tree_id < num_tree_per_iteration_; ++tree_id) { for (int tree_id = 0; tree_id < num_tree_per_iteration_; ++tree_id) {
int model_index = iter * num_tree_per_iteration_ + tree_id; int model_index = iter * num_tree_per_iteration_ + tree_id;
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int i = 0; i < num_data_; ++i) { for (int i = 0; i < num_data_; ++i) {
leaf_pred[i] = tree_leaf_prediction[i][model_index]; leaf_pred[i] = tree_leaf_prediction[i * ncol + model_index];
CHECK_LT(leaf_pred[i], models_[model_index]->num_leaves()); CHECK_LT(leaf_pred[i], models_[model_index]->num_leaves());
} }
size_t offset = static_cast<size_t>(tree_id) * num_data_; size_t offset = static_cast<size_t>(tree_id) * num_data_;
......
...@@ -143,7 +143,7 @@ class GBDT : public GBDTBase { ...@@ -143,7 +143,7 @@ class GBDT : public GBDTBase {
*/ */
void Train(int snapshot_freq, const std::string& model_output_path) override; void Train(int snapshot_freq, const std::string& model_output_path) override;
void RefitTree(const std::vector<std::vector<int>>& tree_leaf_prediction) override; void RefitTree(const int* tree_leaf_prediction, const size_t nrow, const size_t ncol) override;
/*! /*!
* \brief Training logic * \brief Training logic
......
...@@ -409,13 +409,7 @@ class Booster { ...@@ -409,13 +409,7 @@ class Booster {
void Refit(const int32_t* leaf_preds, int32_t nrow, int32_t ncol) { void Refit(const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
UNIQUE_LOCK(mutex_) UNIQUE_LOCK(mutex_)
std::vector<std::vector<int32_t>> v_leaf_preds(nrow, std::vector<int32_t>(ncol, 0)); boosting_->RefitTree(leaf_preds, nrow, ncol);
for (int i = 0; i < nrow; ++i) {
for (int j = 0; j < ncol; ++j) {
v_leaf_preds[i][j] = leaf_preds[static_cast<size_t>(i) * static_cast<size_t>(ncol) + static_cast<size_t>(j)];
}
}
boosting_->RefitTree(v_leaf_preds);
} }
bool TrainOneIter(const score_t* gradients, const score_t* hessians) { bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment