Commit 7d35beec authored by Guolin Ke's avatar Guolin Ke
Browse files

fix #1174

parent 314342ce
#ifndef LIGHTGBM_PREDICTOR_HPP_ #ifndef LIGHTGBM_PREDICTOR_HPP_
#define LIGHTGBM_PREDICTOR_HPP_ #define LIGHTGBM_PREDICTOR_HPP_
#define MAX_FEATURE 10000
#define SPARSITY 100
#include <LightGBM/meta.h> #include <LightGBM/meta.h>
#include <LightGBM/boosting.h> #include <LightGBM/boosting.h>
#include <LightGBM/utils/text_reader.h> #include <LightGBM/utils/text_reader.h>
...@@ -63,14 +60,14 @@ public: ...@@ -63,14 +60,14 @@ public:
num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf_index, is_predict_contrib); num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf_index, is_predict_contrib);
num_feature_ = boosting_->MaxFeatureIdx() + 1; num_feature_ = boosting_->MaxFeatureIdx() + 1;
predict_buf_ = std::vector<std::vector<double>>(num_threads_, std::vector<double>(num_feature_, 0.0f)); predict_buf_ = std::vector<std::vector<double>>(num_threads_, std::vector<double>(num_feature_, 0.0f));
predict_buf_map_ = std::vector<std::unordered_map<int, double>>(num_threads_); const int kFeatureThreshold = 20000;
const size_t KSparseThreshold = static_cast<size_t>(0.02 * num_feature_);
if (is_predict_leaf_index) { if (is_predict_leaf_index) {
predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) { predict_fun_ = [this, kFeatureThreshold, KSparseThreshold](const std::vector<std::pair<int, double>>& features, double* output) {
int tid = omp_get_thread_num(); int tid = omp_get_thread_num();
if (num_feature_ > MAX_FEATURE && num_feature_ / static_cast<int>(features.size()) > SPARSITY) { if (num_feature_ > kFeatureThreshold && features.size() < KSparseThreshold) {
CopyToPredictMap(tid, features); auto buf = CopyToPredictMap(features);
boosting_->PredictLeafIndexByMap(predict_buf_map_[tid], output); boosting_->PredictLeafIndexByMap(buf, output);
ClearPredictMap(tid);
} else { } else {
CopyToPredictBuffer(predict_buf_[tid].data(), features); CopyToPredictBuffer(predict_buf_[tid].data(), features);
// get result for leaf index // get result for leaf index
...@@ -88,12 +85,11 @@ public: ...@@ -88,12 +85,11 @@ public:
}; };
} else { } else {
if (is_raw_score) { if (is_raw_score) {
predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) { predict_fun_ = [this, kFeatureThreshold, KSparseThreshold](const std::vector<std::pair<int, double>>& features, double* output) {
int tid = omp_get_thread_num(); int tid = omp_get_thread_num();
if (num_feature_ > MAX_FEATURE && num_feature_ / static_cast<int>(features.size()) > SPARSITY) { if (num_feature_ > kFeatureThreshold && features.size() < KSparseThreshold) {
CopyToPredictMap(tid, features); auto buf = CopyToPredictMap(features);
boosting_->PredictRawByMap(predict_buf_map_[tid], output, &early_stop_); boosting_->PredictRawByMap(buf, output, &early_stop_);
ClearPredictMap(tid);
} else { } else {
CopyToPredictBuffer(predict_buf_[tid].data(), features); CopyToPredictBuffer(predict_buf_[tid].data(), features);
boosting_->PredictRaw(predict_buf_[tid].data(), output, &early_stop_); boosting_->PredictRaw(predict_buf_[tid].data(), output, &early_stop_);
...@@ -101,12 +97,11 @@ public: ...@@ -101,12 +97,11 @@ public:
} }
}; };
} else { } else {
predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) { predict_fun_ = [this, kFeatureThreshold, KSparseThreshold](const std::vector<std::pair<int, double>>& features, double* output) {
int tid = omp_get_thread_num(); int tid = omp_get_thread_num();
if (num_feature_ > MAX_FEATURE && num_feature_ / static_cast<int>(features.size()) > SPARSITY) { if (num_feature_ > kFeatureThreshold && features.size() < KSparseThreshold) {
CopyToPredictMap(tid, features); auto buf = CopyToPredictMap(features);
boosting_->PredictByMap(predict_buf_map_[tid], output, &early_stop_); boosting_->PredictByMap(buf, output, &early_stop_);
ClearPredictMap(tid);
} else { } else {
CopyToPredictBuffer(predict_buf_[tid].data(), features); CopyToPredictBuffer(predict_buf_[tid].data(), features);
boosting_->Predict(predict_buf_[tid].data(), output, &early_stop_); boosting_->Predict(predict_buf_[tid].data(), output, &early_stop_);
...@@ -245,17 +240,15 @@ private: ...@@ -245,17 +240,15 @@ private:
} }
} }
void CopyToPredictMap(int tid, const std::vector<std::pair<int, double>>& features) { std::unordered_map<int, double> CopyToPredictMap(const std::vector<std::pair<int, double>>& features) {
std::unordered_map<int, double> buf;
int loop_size = static_cast<int>(features.size()); int loop_size = static_cast<int>(features.size());
for (int i = 0; i < loop_size; ++i) { for (int i = 0; i < loop_size; ++i) {
if (features[i].first < num_feature_) { if (features[i].first < num_feature_) {
predict_buf_map_[tid][features[i].first] = features[i].second; buf[features[i].first] = features[i].second;
} }
} }
} return std::move(buf);
void ClearPredictMap(int tid) {
predict_buf_map_[tid].clear();
} }
/*! \brief Boosting model */ /*! \brief Boosting model */
...@@ -267,7 +260,6 @@ private: ...@@ -267,7 +260,6 @@ private:
int num_pred_one_row_; int num_pred_one_row_;
int num_threads_; int num_threads_;
std::vector<std::vector<double>> predict_buf_; std::vector<std::vector<double>> predict_buf_;
std::vector<std::unordered_map<int, double>> predict_buf_map_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment