Commit 4e291459 authored by Guolin Ke
Browse files

move num_used_model out of predict function

parent 01ed04df
...@@ -55,37 +55,31 @@ public: ...@@ -55,37 +55,31 @@ public:
/*! /*!
* \brief Prediction for one record, not sigmoid transform * \brief Prediction for one record, not sigmoid transform
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Prediction result for this record * \return Prediction result for this record
*/ */
virtual double PredictRaw(const double* feature_values, virtual double PredictRaw(const double* feature_values) const = 0;
int num_used_model) const = 0;
/*! /*!
* \brief Prediction for one record, sigmoid transformation will be used if needed * \brief Prediction for one record, sigmoid transformation will be used if needed
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Prediction result for this record * \return Prediction result for this record
*/ */
virtual double Predict(const double* feature_values, virtual double Predict(const double* feature_values) const = 0;
int num_used_model) const = 0;
/*! /*!
* \brief Prediction for one record with leaf index * \brief Prediction for one record with leaf index
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Predicted leaf index for this record * \return Predicted leaf index for this record
*/ */
virtual std::vector<int> PredictLeafIndex( virtual std::vector<int> PredictLeafIndex(
const double* feature_values, const double* feature_values) const = 0;
int num_used_model) const = 0;
/*! /*!
* \brief Prediction for multiclass classification * \brief Prediction for multiclass classification
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \return Prediction result, num_class numbers per line * \return Prediction result, num_class numbers per line
*/ */
virtual std::vector<double> PredictMulticlass(const double* value, int num_used_model) const = 0; virtual std::vector<double> PredictMulticlass(const double* value) const = 0;
/*! /*!
* \brief save model to file * \brief save model to file
...@@ -121,6 +115,11 @@ public: ...@@ -121,6 +115,11 @@ public:
* \return Number of classes * \return Number of classes
*/ */
virtual int NumberOfClass() const = 0; virtual int NumberOfClass() const = 0;
/*!
* \brief Set number of used model for prediction
* \param num_used_model Number of used model
*/
virtual void SetNumUsedModel(int num_used_model) = 0;
/*! /*!
* \brief Get Type name of this boosting object * \brief Get Type name of this boosting object
......
...@@ -123,7 +123,7 @@ void Application::LoadData() { ...@@ -123,7 +123,7 @@ void Application::LoadData() {
Predictor* predictor = nullptr; Predictor* predictor = nullptr;
// need to continue train // need to continue train
if (boosting_->NumberOfSubModels() > 0) { if (boosting_->NumberOfSubModels() > 0) {
predictor = new Predictor(boosting_, config_.io_config.is_sigmoid, config_.predict_leaf_index, -1); predictor = new Predictor(boosting_, config_.io_config.is_sigmoid, config_.predict_leaf_index);
if (config_.io_config.num_class == 1){ if (config_.io_config.num_class == 1){
predict_fun = predict_fun =
[&predictor](const std::vector<std::pair<int, double>>& features) { [&predictor](const std::vector<std::pair<int, double>>& features) {
...@@ -265,9 +265,10 @@ void Application::Train() { ...@@ -265,9 +265,10 @@ void Application::Train() {
void Application::Predict() { void Application::Predict() {
boosting_->SetNumUsedModel(config_.io_config.num_model_predict);
// create predictor // create predictor
Predictor predictor(boosting_, config_.io_config.is_sigmoid, Predictor predictor(boosting_, config_.io_config.is_sigmoid,
config_.predict_leaf_index, config_.io_config.num_model_predict); config_.predict_leaf_index);
predictor.Predict(config_.io_config.data_filename.c_str(), predictor.Predict(config_.io_config.data_filename.c_str(),
config_.io_config.output_result.c_str(), config_.io_config.has_header); config_.io_config.output_result.c_str(), config_.io_config.has_header);
Log::Info("Finish predict."); Log::Info("Finish predict.");
......
...@@ -28,9 +28,8 @@ public: ...@@ -28,9 +28,8 @@ public:
* \param is_sigmoid True if need to predict result with sigmoid transform(if needed, like binary classification) * \param is_sigmoid True if need to predict result with sigmoid transform(if needed, like binary classification)
* \param predict_leaf_index True if output leaf index instead of prediction score * \param predict_leaf_index True if output leaf index instead of prediction score
*/ */
Predictor(const Boosting* boosting, bool is_simgoid, bool is_predict_leaf_index, int num_used_model) Predictor(const Boosting* boosting, bool is_simgoid, bool is_predict_leaf_index)
: is_simgoid_(is_simgoid), is_predict_leaf_index_(is_predict_leaf_index), : is_simgoid_(is_simgoid), is_predict_leaf_index_(is_predict_leaf_index) {
num_used_model_(num_used_model) {
boosting_ = boosting; boosting_ = boosting;
num_features_ = boosting_->MaxFeatureIdx() + 1; num_features_ = boosting_->MaxFeatureIdx() + 1;
num_class_ = boosting_->NumberOfClass(); num_class_ = boosting_->NumberOfClass();
...@@ -64,7 +63,7 @@ public: ...@@ -64,7 +63,7 @@ public:
std::vector<double> PredictRawOneLine(const std::vector<std::pair<int, double>>& features) { std::vector<double> PredictRawOneLine(const std::vector<std::pair<int, double>>& features) {
const int tid = PutFeatureValuesToBuffer(features); const int tid = PutFeatureValuesToBuffer(features);
// get result without sigmoid transformation // get result without sigmoid transformation
return std::vector<double>(1, boosting_->PredictRaw(features_[tid], num_used_model_)); return std::vector<double>(1, boosting_->PredictRaw(features_[tid]));
} }
/*! /*!
...@@ -75,7 +74,7 @@ public: ...@@ -75,7 +74,7 @@ public:
std::vector<int> PredictLeafIndexOneLine(const std::vector<std::pair<int, double>>& features) { std::vector<int> PredictLeafIndexOneLine(const std::vector<std::pair<int, double>>& features) {
const int tid = PutFeatureValuesToBuffer(features); const int tid = PutFeatureValuesToBuffer(features);
// get result for leaf index // get result for leaf index
return boosting_->PredictLeafIndex(features_[tid], num_used_model_); return boosting_->PredictLeafIndex(features_[tid]);
} }
/*! /*!
...@@ -86,7 +85,7 @@ public: ...@@ -86,7 +85,7 @@ public:
std::vector<double> PredictOneLine(const std::vector<std::pair<int, double>>& features) { std::vector<double> PredictOneLine(const std::vector<std::pair<int, double>>& features) {
const int tid = PutFeatureValuesToBuffer(features); const int tid = PutFeatureValuesToBuffer(features);
// get result with sigmoid transform if needed // get result with sigmoid transform if needed
return std::vector<double>(1, boosting_->Predict(features_[tid], num_used_model_)); return std::vector<double>(1, boosting_->Predict(features_[tid]));
} }
/*! /*!
...@@ -97,7 +96,7 @@ public: ...@@ -97,7 +96,7 @@ public:
std::vector<double> PredictMulticlassOneLine(const std::vector<std::pair<int, double>>& features) { std::vector<double> PredictMulticlassOneLine(const std::vector<std::pair<int, double>>& features) {
const int tid = PutFeatureValuesToBuffer(features); const int tid = PutFeatureValuesToBuffer(features);
// get result with sigmoid transform if needed // get result with sigmoid transform if needed
return boosting_->PredictMulticlass(features_[tid], num_used_model_); return boosting_->PredictMulticlass(features_[tid]);
} }
/*! /*!
...@@ -224,8 +223,6 @@ private: ...@@ -224,8 +223,6 @@ private:
int num_threads_; int num_threads_;
/*! \brief True if output leaf index instead of prediction score */ /*! \brief True if output leaf index instead of prediction score */
bool is_predict_leaf_index_; bool is_predict_leaf_index_;
/*! \brief Number of used model */
int num_used_model_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -19,7 +19,8 @@ namespace LightGBM { ...@@ -19,7 +19,8 @@ namespace LightGBM {
GBDT::GBDT() GBDT::GBDT()
: train_score_updater_(nullptr), : train_score_updater_(nullptr),
gradients_(nullptr), hessians_(nullptr), gradients_(nullptr), hessians_(nullptr),
out_of_bag_data_indices_(nullptr), bag_data_indices_(nullptr) { out_of_bag_data_indices_(nullptr), bag_data_indices_(nullptr),
saved_model_size_(-1), num_used_model_(0) {
} }
GBDT::~GBDT() { GBDT::~GBDT() {
...@@ -43,6 +44,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O ...@@ -43,6 +44,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
const std::vector<const Metric*>& training_metrics) { const std::vector<const Metric*>& training_metrics) {
gbdt_config_ = dynamic_cast<const GBDTConfig*>(config); gbdt_config_ = dynamic_cast<const GBDTConfig*>(config);
iter_ = 0; iter_ = 0;
saved_model_size_ = -1;
max_feature_idx_ = 0; max_feature_idx_ = 0;
early_stopping_round_ = gbdt_config_->early_stopping_round; early_stopping_round_ = gbdt_config_->early_stopping_round;
train_data_ = train_data; train_data_ = train_data;
...@@ -438,6 +440,7 @@ void GBDT::ModelsFromString(const std::string& model_str) { ...@@ -438,6 +440,7 @@ void GBDT::ModelsFromString(const std::string& model_str) {
} }
} }
Log::Info("%d models has been loaded\n", models_.size()); Log::Info("%d models has been loaded\n", models_.size());
num_used_model_ = static_cast<int>(models_.size()) / num_class_;
} }
std::string GBDT::FeatureImportance() const { std::string GBDT::FeatureImportance() const {
...@@ -467,23 +470,17 @@ std::string GBDT::FeatureImportance() const { ...@@ -467,23 +470,17 @@ std::string GBDT::FeatureImportance() const {
return str_buf.str(); return str_buf.str();
} }
double GBDT::PredictRaw(const double* value, int num_used_model) const { double GBDT::PredictRaw(const double* value) const {
if (num_used_model < 0) {
num_used_model = static_cast<int>(models_.size());
}
double ret = 0.0f; double ret = 0.0f;
for (int i = 0; i < num_used_model; ++i) { for (int i = 0; i < num_used_model_; ++i) {
ret += models_[i]->Predict(value); ret += models_[i]->Predict(value);
} }
return ret; return ret;
} }
double GBDT::Predict(const double* value, int num_used_model) const { double GBDT::Predict(const double* value) const {
if (num_used_model < 0) {
num_used_model = static_cast<int>(models_.size());
}
double ret = 0.0f; double ret = 0.0f;
for (int i = 0; i < num_used_model; ++i) { for (int i = 0; i < num_used_model_; ++i) {
ret += models_[i]->Predict(value); ret += models_[i]->Predict(value);
} }
// if need sigmoid transform // if need sigmoid transform
...@@ -493,12 +490,9 @@ double GBDT::Predict(const double* value, int num_used_model) const { ...@@ -493,12 +490,9 @@ double GBDT::Predict(const double* value, int num_used_model) const {
return ret; return ret;
} }
std::vector<double> GBDT::PredictMulticlass(const double* value, int num_used_model) const { std::vector<double> GBDT::PredictMulticlass(const double* value) const {
if (num_used_model < 0) {
num_used_model = static_cast<int>(models_.size()) / num_class_;
}
std::vector<double> ret(num_class_, 0.0f); std::vector<double> ret(num_class_, 0.0f);
for (int i = 0; i < num_used_model; ++i) { for (int i = 0; i < num_used_model_; ++i) {
for (int j = 0; j < num_class_; ++j){ for (int j = 0; j < num_class_; ++j){
ret[j] += models_[i * num_class_ + j] -> Predict(value); ret[j] += models_[i * num_class_ + j] -> Predict(value);
} }
...@@ -506,12 +500,9 @@ std::vector<double> GBDT::PredictMulticlass(const double* value, int num_used_mo ...@@ -506,12 +500,9 @@ std::vector<double> GBDT::PredictMulticlass(const double* value, int num_used_mo
return ret; return ret;
} }
std::vector<int> GBDT::PredictLeafIndex(const double* value, int num_used_model) const { std::vector<int> GBDT::PredictLeafIndex(const double* value) const {
if (num_used_model < 0) {
num_used_model = static_cast<int>(models_.size());
}
std::vector<int> ret; std::vector<int> ret;
for (int i = 0; i < num_used_model; ++i) { for (int i = 0; i < num_used_model_; ++i) {
ret.push_back(models_[i]->PredictLeafIndex(value)); ret.push_back(models_[i]->PredictLeafIndex(value));
} }
return ret; return ret;
......
...@@ -55,33 +55,30 @@ public: ...@@ -55,33 +55,30 @@ public:
/*! /*!
* \brief Prediction for one record without sigmoid transformation * \brief Prediction for one record without sigmoid transformation
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Prediction result for this record * \return Prediction result for this record
*/ */
double PredictRaw(const double* feature_values, int num_used_model) const override; double PredictRaw(const double* feature_values) const override;
/*! /*!
* \brief Prediction for one record with sigmoid transformation if enabled * \brief Prediction for one record with sigmoid transformation if enabled
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Prediction result for this record * \return Prediction result for this record
*/ */
double Predict(const double* feature_values, int num_used_model) const override; double Predict(const double* feature_values) const override;
/*! /*!
* \brief Prediction for multiclass classification * \brief Prediction for multiclass classification
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \return Prediction result, num_class numbers per line * \return Prediction result, num_class numbers per line
*/ */
std::vector<double> PredictMulticlass(const double* value, int num_used_model) const override; std::vector<double> PredictMulticlass(const double* value) const override;
/*! /*!
* \brief Prediction for one record with leaf index * \brief Prediction for one record with leaf index
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Predicted leaf index for this record * \return Predicted leaf index for this record
*/ */
std::vector<int> PredictLeafIndex(const double* value, int num_used_model) const override; std::vector<int> PredictLeafIndex(const double* value) const override;
/*! /*!
* \brief Serialize models by string * \brief Serialize models by string
...@@ -115,6 +112,16 @@ public: ...@@ -115,6 +112,16 @@ public:
* \return Number of classes * \return Number of classes
*/ */
inline int NumberOfClass() const override { return num_class_; } inline int NumberOfClass() const override { return num_class_; }
/*!
* \brief Set number of used model for prediction
*/
inline void SetNumUsedModel(int num_used_model) {
if (num_used_model >= 0) {
num_used_model_ = static_cast<int>(num_used_model / num_class_);
}
}
/*! /*!
* \brief Get Type name of this boosting object * \brief Get Type name of this boosting object
...@@ -208,9 +215,11 @@ private: ...@@ -208,9 +215,11 @@ private:
/*! \brief Index of label column */ /*! \brief Index of label column */
data_size_t label_idx_; data_size_t label_idx_;
/*! \brief Saved number of models */ /*! \brief Saved number of models */
int saved_model_size_ = -1; int saved_model_size_;
/*! \brief File to write models */ /*! \brief File to write models */
std::ofstream model_output_file_; std::ofstream model_output_file_;
/*! \brief number of used model */
int num_used_model_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment