"python-package/vscode:/vscode.git/clone" did not exist on "bdb02e05ad5c870dbb082409eda7388879de2452"
Commit 89c69987 authored by Guolin Ke's avatar Guolin Ke
Browse files

fix bug in continued-train.

parent 922013b0
......@@ -58,6 +58,7 @@ catch(...) { omp_except_helper.CaptureException(); }
simulate a single thread running.
All #pragma omp should be ignored by the compiler **/
inline void omp_set_num_threads(int) {}
inline void omp_set_nested(int) {}
inline int omp_get_num_threads() {return 1;}
inline int omp_get_thread_num() {return 0;}
#ifdef __cplusplus
......
......@@ -35,6 +35,7 @@ Application::Application(int argc, char** argv) {
if (config_.io_config.data_filename.size() == 0 && config_.task_type != TaskType::kConvertModel) {
Log::Fatal("No training/prediction data, application quit");
}
omp_set_nested(0);
}
Application::~Application() {
......
......@@ -32,32 +32,39 @@ public:
*/
Predictor(Boosting* boosting, int num_iteration,
bool is_raw_score, bool is_predict_leaf_index) {
#pragma omp parallel
#pragma omp master
{
num_threads_ = omp_get_num_threads();
}
boosting->InitPredict(num_iteration);
boosting_ = boosting;
num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf_index);
predict_buf_ = std::vector<double>(boosting_->MaxFeatureIdx() + 1, 0.0f);
predict_buf_ = std::vector<std::vector<double>>(num_threads_, std::vector<double>(boosting_->MaxFeatureIdx() + 1, 0.0f));
if (is_predict_leaf_index) {
predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
CopyToPredictBuffer(features);
int tid = omp_get_thread_num();
CopyToPredictBuffer(predict_buf_[tid].data(), features);
// get result for leaf index
boosting_->PredictLeafIndex(predict_buf_.data(), output);
ClearPredictBuffer(features);
boosting_->PredictLeafIndex(predict_buf_[tid].data(), output);
ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
};
} else {
if (is_raw_score) {
predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
CopyToPredictBuffer(features);
boosting_->PredictRaw(predict_buf_.data(), output);
ClearPredictBuffer(features);
int tid = omp_get_thread_num();
CopyToPredictBuffer(predict_buf_[tid].data(), features);
boosting_->PredictRaw(predict_buf_[tid].data(), output);
ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
};
} else {
predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
CopyToPredictBuffer(features);
boosting_->Predict(predict_buf_.data(), output);
ClearPredictBuffer(features);
int tid = omp_get_thread_num();
CopyToPredictBuffer(predict_buf_[tid].data(), features);
boosting_->Predict(predict_buf_[tid].data(), output);
ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
};
}
}
......@@ -126,22 +133,22 @@ public:
private:
void CopyToPredictBuffer(const std::vector<std::pair<int, double>>& features) {
void CopyToPredictBuffer(double* pred_buf, const std::vector<std::pair<int, double>>& features) {
int loop_size = static_cast<int>(features.size());
#pragma omp parallel for schedule(static,128) if (loop_size >= 256)
for (int i = 0; i < loop_size; ++i) {
predict_buf_[features[i].first] = features[i].second;
pred_buf[features[i].first] = features[i].second;
}
}
void ClearPredictBuffer(const std::vector<std::pair<int, double>>& features) {
if (features.size() < static_cast<size_t>(predict_buf_.size() / 2)) {
std::memset(predict_buf_.data(), 0, sizeof(double)*(predict_buf_.size()));
void ClearPredictBuffer(double* pred_buf, size_t buf_size, const std::vector<std::pair<int, double>>& features) {
if (features.size() < static_cast<size_t>(buf_size / 2)) {
std::memset(pred_buf, 0, sizeof(double)*(buf_size));
} else {
int loop_size = static_cast<int>(features.size());
#pragma omp parallel for schedule(static,128) if (loop_size >= 256)
for (int i = 0; i < loop_size; ++i) {
predict_buf_[features[i].first] = 0.0f;
pred_buf[features[i].first] = 0.0f;
}
}
}
......@@ -151,7 +158,8 @@ private:
/*! \brief function for prediction */
PredictFunction predict_fun_;
int num_pred_one_row_;
std::vector<double> predict_buf_;
int num_threads_;
std::vector<std::vector<double>> predict_buf_;
};
} // namespace LightGBM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment