Commit 34a2ff2d authored by Guolin Ke's avatar Guolin Ke
Browse files

fix index out of range in predictor.

parent 4de349e7
...@@ -40,7 +40,8 @@ public: ...@@ -40,7 +40,8 @@ public:
boosting->InitPredict(num_iteration); boosting->InitPredict(num_iteration);
boosting_ = boosting; boosting_ = boosting;
num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf_index); num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf_index);
predict_buf_ = std::vector<std::vector<double>>(num_threads_, std::vector<double>(boosting_->MaxFeatureIdx() + 1, 0.0f)); num_feature_ = boosting_->MaxFeatureIdx() + 1;
predict_buf_ = std::vector<std::vector<double>>(num_threads_, std::vector<double>(num_feature_, 0.0f));
if (is_predict_leaf_index) { if (is_predict_leaf_index) {
predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) { predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
...@@ -137,9 +138,11 @@ private: ...@@ -137,9 +138,11 @@ private:
int loop_size = static_cast<int>(features.size()); int loop_size = static_cast<int>(features.size());
#pragma omp parallel for schedule(static,128) if (loop_size >= 256) #pragma omp parallel for schedule(static,128) if (loop_size >= 256)
for (int i = 0; i < loop_size; ++i) { for (int i = 0; i < loop_size; ++i) {
if (features[i].first < num_feature_) {
pred_buf[features[i].first] = features[i].second; pred_buf[features[i].first] = features[i].second;
} }
} }
}
void ClearPredictBuffer(double* pred_buf, size_t buf_size, const std::vector<std::pair<int, double>>& features) { void ClearPredictBuffer(double* pred_buf, size_t buf_size, const std::vector<std::pair<int, double>>& features) {
if (features.size() < static_cast<size_t>(buf_size / 2)) { if (features.size() < static_cast<size_t>(buf_size / 2)) {
...@@ -157,6 +160,7 @@ private: ...@@ -157,6 +160,7 @@ private:
const Boosting* boosting_; const Boosting* boosting_;
/*! \brief function for prediction */ /*! \brief function for prediction */
PredictFunction predict_fun_; PredictFunction predict_fun_;
int num_feature_;
int num_pred_one_row_; int num_pred_one_row_;
int num_threads_; int num_threads_;
std::vector<std::vector<double>> predict_buf_; std::vector<std::vector<double>> predict_buf_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment