Commit 89c69987 authored by Guolin Ke's avatar Guolin Ke
Browse files

Fix a bug in continued training.

parent 922013b0
...@@ -58,6 +58,7 @@ catch(...) { omp_except_helper.CaptureException(); } ...@@ -58,6 +58,7 @@ catch(...) { omp_except_helper.CaptureException(); }
simulate a single thread running. simulate a single thread running.
All #pragma omp should be ignored by the compiler **/ All #pragma omp should be ignored by the compiler **/
/* No-op stand-ins for the OpenMP runtime API, used when the compiler has no
   OpenMP support. They model a single-threaded run: one thread total, always
   thread id 0, and the setters silently accept and ignore their arguments. */
inline void omp_set_num_threads(int) {}
inline void omp_set_nested(int) {}
inline int omp_get_num_threads() { return 1; }
inline int omp_get_thread_num() { return 0; }
#ifdef __cplusplus #ifdef __cplusplus
......
...@@ -35,6 +35,7 @@ Application::Application(int argc, char** argv) { ...@@ -35,6 +35,7 @@ Application::Application(int argc, char** argv) {
if (config_.io_config.data_filename.size() == 0 && config_.task_type != TaskType::kConvertModel) { if (config_.io_config.data_filename.size() == 0 && config_.task_type != TaskType::kConvertModel) {
Log::Fatal("No training/prediction data, application quit"); Log::Fatal("No training/prediction data, application quit");
} }
omp_set_nested(0);
} }
Application::~Application() { Application::~Application() {
......
...@@ -32,32 +32,39 @@ public: ...@@ -32,32 +32,39 @@ public:
*/ */
/*!
 * \brief Prepares a prediction function over a trained boosting model.
 * \param boosting Trained model; not owned by this object.
 * \param num_iteration Number of boosting iterations to use for prediction.
 * \param is_raw_score True to return raw (untransformed) scores.
 * \param is_predict_leaf_index True to return per-tree leaf indices instead of scores.
 */
Predictor(Boosting* boosting, int num_iteration,
          bool is_raw_score, bool is_predict_leaf_index) {
#pragma omp parallel
#pragma omp master
  {
    // Ask OpenMP (from inside a parallel region) how many threads it will
    // spawn, so one scratch buffer per thread can be pre-allocated below.
    num_threads_ = omp_get_num_threads();
  }
  boosting->InitPredict(num_iteration);
  boosting_ = boosting;
  num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf_index);
  // One dense feature buffer per thread so concurrent calls to predict_fun_
  // never share (and race on) the same scratch memory.
  // NOTE(review): assumes prediction later runs with at most num_threads_
  // threads; a larger team would index predict_buf_ out of range — confirm.
  predict_buf_ = std::vector<std::vector<double>>(
      num_threads_, std::vector<double>(boosting_->MaxFeatureIdx() + 1, 0.0f));
  if (is_predict_leaf_index) {
    // Predict the leaf index of every tree for one row of sparse features.
    predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
      const int thread_idx = omp_get_thread_num();
      double* buf = predict_buf_[thread_idx].data();
      CopyToPredictBuffer(buf, features);
      boosting_->PredictLeafIndex(buf, output);
      ClearPredictBuffer(buf, predict_buf_[thread_idx].size(), features);
    };
  } else if (is_raw_score) {
    // Predict raw (untransformed) scores for one row of sparse features.
    predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
      const int thread_idx = omp_get_thread_num();
      double* buf = predict_buf_[thread_idx].data();
      CopyToPredictBuffer(buf, features);
      boosting_->PredictRaw(buf, output);
      ClearPredictBuffer(buf, predict_buf_[thread_idx].size(), features);
    };
  } else {
    // Predict final (transformed) scores for one row of sparse features.
    predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
      const int thread_idx = omp_get_thread_num();
      double* buf = predict_buf_[thread_idx].data();
      CopyToPredictBuffer(buf, features);
      boosting_->Predict(buf, output);
      ClearPredictBuffer(buf, predict_buf_[thread_idx].size(), features);
    };
  }
}
...@@ -126,22 +133,22 @@ public: ...@@ -126,22 +133,22 @@ public:
private: private:
/*!
 * \brief Scatter a sparse feature row into a dense prediction buffer.
 * \param pred_buf Dense buffer; must be at least (max feature index + 1) wide.
 * \param features Sparse (feature index, value) pairs to write into pred_buf.
 */
void CopyToPredictBuffer(double* pred_buf, const std::vector<std::pair<int, double>>& features) {
  const int num_features = static_cast<int>(features.size());
  // Only pay the parallel-region cost when the row is long enough.
  #pragma omp parallel for schedule(static,128) if (num_features >= 256)
  for (int idx = 0; idx < num_features; ++idx) {
    const auto& entry = features[idx];
    pred_buf[entry.first] = entry.second;
  }
}
/*!
 * \brief Reset the entries of a dense prediction buffer back to zero.
 *
 * When the row touched at least half of the buffer, clearing only the touched
 * slots costs about as much as wiping everything, so the sparse path is used
 * only for genuinely sparse rows.
 * \param pred_buf Dense buffer previously filled by CopyToPredictBuffer.
 * \param buf_size Number of doubles in pred_buf.
 * \param features The same sparse pairs that were written into pred_buf.
 */
void ClearPredictBuffer(double* pred_buf, size_t buf_size, const std::vector<std::pair<int, double>>& features) {
  if (features.size() >= static_cast<size_t>(buf_size / 2)) {
    // Dense row: zero only the slots that were actually written.
    const int num_features = static_cast<int>(features.size());
    #pragma omp parallel for schedule(static,128) if (num_features >= 256)
    for (int idx = 0; idx < num_features; ++idx) {
      pred_buf[features[idx].first] = 0.0f;
    }
  } else {
    // Sparse row: one bulk wipe of the whole buffer is cheaper.
    std::memset(pred_buf, 0, sizeof(double) * buf_size);
  }
}
...@@ -151,7 +158,8 @@ private: ...@@ -151,7 +158,8 @@ private:
/*! \brief function for prediction */ /*! \brief function for prediction */
PredictFunction predict_fun_; PredictFunction predict_fun_;
int num_pred_one_row_; int num_pred_one_row_;
std::vector<double> predict_buf_; int num_threads_;
std::vector<std::vector<double>> predict_buf_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment