Commit a24b7fd4 authored by Allardvm's avatar Allardvm
Browse files

Fixed prediction bug when num_used_model = NO_LIMIT / -1

parent 92351659
...@@ -76,26 +76,26 @@ public: ...@@ -76,26 +76,26 @@ public:
void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) const override; void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) const override;
/*! /*!
* \brief Predtion for one record without sigmoid transformation * \brief Prediction for one record without sigmoid transformation
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \return Prediction result for this record * \return Prediction result for this record
*/ */
std::vector<double> PredictRaw(const double* feature_values) const override; std::vector<double> PredictRaw(const double* feature_values) const override;
/*! /*!
* \brief Predtion for one record with sigmoid transformation if enabled * \brief Prediction for one record with sigmoid transformation if enabled
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \return Prediction result for this record * \return Prediction result for this record
*/ */
std::vector<double> Predict(const double* feature_values) const override; std::vector<double> Predict(const double* feature_values) const override;
/*! /*!
* \brief Predtion for one record with leaf index * \brief Prediction for one record with leaf index
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \return Predicted leaf index for this record * \return Predicted leaf index for this record
*/ */
std::vector<int> PredictLeafIndex(const double* value) const override; std::vector<int> PredictLeafIndex(const double* value) const override;
/*! /*!
* \brief save model to file * \brief save model to file
* \param num_used_model number of model that want to save, -1 means save all * \param num_used_model number of model that want to save, -1 means save all
...@@ -137,9 +137,11 @@ public: ...@@ -137,9 +137,11 @@ public:
inline void SetNumUsedModel(int num_used_model) { inline void SetNumUsedModel(int num_used_model) {
if (num_used_model >= 0) { if (num_used_model >= 0) {
num_used_model_ = static_cast<int>(num_used_model / num_class_); num_used_model_ = static_cast<int>(num_used_model / num_class_);
} else {
num_used_model_ = static_cast<int>(models_.size()) / num_class_;
} }
} }
/*! /*!
* \brief Get Type name of this boosting object * \brief Get Type name of this boosting object
*/ */
...@@ -218,7 +220,7 @@ protected: ...@@ -218,7 +220,7 @@ protected:
std::vector<data_size_t> bag_data_indices_; std::vector<data_size_t> bag_data_indices_;
/*! \brief Number of in-bag data */ /*! \brief Number of in-bag data */
data_size_t bag_data_cnt_; data_size_t bag_data_cnt_;
/*! \brief Number of traning data */ /*! \brief Number of training data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Number of classes */ /*! \brief Number of classes */
int num_class_; int num_class_;
...@@ -226,7 +228,7 @@ protected: ...@@ -226,7 +228,7 @@ protected:
Random random_; Random random_;
/*! /*!
* \brief Sigmoid parameter, used for prediction. * \brief Sigmoid parameter, used for prediction.
* if > 0 meas output score will transform by sigmoid function * if > 0 means output score will transform by sigmoid function
*/ */
double sigmoid_; double sigmoid_;
/*! \brief Index of label column */ /*! \brief Index of label column */
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment