Commit 4e291459 authored by Guolin Ke's avatar Guolin Ke
Browse files

move num_used_model out of predict function

parent 01ed04df
......@@ -55,37 +55,31 @@ public:
/*!
* \brief Prediction for one record, without sigmoid transformation
* \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Prediction result for this record
*/
virtual double PredictRaw(const double* feature_values,
int num_used_model) const = 0;
virtual double PredictRaw(const double* feature_values) const = 0;
/*!
* \brief Prediction for one record, sigmoid transformation will be used if needed
* \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Prediction result for this record
*/
virtual double Predict(const double* feature_values,
int num_used_model) const = 0;
virtual double Predict(const double* feature_values) const = 0;
/*!
* \brief Prediction for one record with leaf index
* \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Predicted leaf index for this record
*/
virtual std::vector<int> PredictLeafIndex(
const double* feature_values,
int num_used_model) const = 0;
const double* feature_values) const = 0;
/*!
* \brief Prediction for multiclass classification
* \param value Feature value on this record
* \return Prediction result, num_class numbers per line
*/
virtual std::vector<double> PredictMulticlass(const double* value, int num_used_model) const = 0;
virtual std::vector<double> PredictMulticlass(const double* value) const = 0;
/*!
* \brief save model to file
......@@ -121,6 +115,11 @@ public:
* \return Number of classes
*/
virtual int NumberOfClass() const = 0;
/*!
* \brief Set the number of models used for prediction
*/
virtual void SetNumUsedModel(int num_used_model) = 0;
/*!
* \brief Get Type name of this boosting object
......
......@@ -123,7 +123,7 @@ void Application::LoadData() {
Predictor* predictor = nullptr;
// need to continue train
if (boosting_->NumberOfSubModels() > 0) {
predictor = new Predictor(boosting_, config_.io_config.is_sigmoid, config_.predict_leaf_index, -1);
predictor = new Predictor(boosting_, config_.io_config.is_sigmoid, config_.predict_leaf_index);
if (config_.io_config.num_class == 1){
predict_fun =
[&predictor](const std::vector<std::pair<int, double>>& features) {
......@@ -265,9 +265,10 @@ void Application::Train() {
void Application::Predict() {
boosting_->SetNumUsedModel(config_.io_config.num_model_predict);
// create predictor
Predictor predictor(boosting_, config_.io_config.is_sigmoid,
config_.predict_leaf_index, config_.io_config.num_model_predict);
config_.predict_leaf_index);
predictor.Predict(config_.io_config.data_filename.c_str(),
config_.io_config.output_result.c_str(), config_.io_config.has_header);
Log::Info("Finish predict.");
......
......@@ -28,9 +28,8 @@ public:
* \param is_sigmoid True if the prediction result should be sigmoid-transformed (when applicable, e.g. binary classification)
* \param predict_leaf_index True if output leaf index instead of prediction score
*/
Predictor(const Boosting* boosting, bool is_simgoid, bool is_predict_leaf_index, int num_used_model)
: is_simgoid_(is_simgoid), is_predict_leaf_index_(is_predict_leaf_index),
num_used_model_(num_used_model) {
Predictor(const Boosting* boosting, bool is_simgoid, bool is_predict_leaf_index)
: is_simgoid_(is_simgoid), is_predict_leaf_index_(is_predict_leaf_index) {
boosting_ = boosting;
num_features_ = boosting_->MaxFeatureIdx() + 1;
num_class_ = boosting_->NumberOfClass();
......@@ -64,7 +63,7 @@ public:
std::vector<double> PredictRawOneLine(const std::vector<std::pair<int, double>>& features) {
const int tid = PutFeatureValuesToBuffer(features);
// get result without sigmoid transformation
return std::vector<double>(1, boosting_->PredictRaw(features_[tid], num_used_model_));
return std::vector<double>(1, boosting_->PredictRaw(features_[tid]));
}
/*!
......@@ -75,7 +74,7 @@ public:
std::vector<int> PredictLeafIndexOneLine(const std::vector<std::pair<int, double>>& features) {
const int tid = PutFeatureValuesToBuffer(features);
// get result for leaf index
return boosting_->PredictLeafIndex(features_[tid], num_used_model_);
return boosting_->PredictLeafIndex(features_[tid]);
}
/*!
......@@ -86,7 +85,7 @@ public:
std::vector<double> PredictOneLine(const std::vector<std::pair<int, double>>& features) {
const int tid = PutFeatureValuesToBuffer(features);
// get result with sigmoid transform if needed
return std::vector<double>(1, boosting_->Predict(features_[tid], num_used_model_));
return std::vector<double>(1, boosting_->Predict(features_[tid]));
}
/*!
......@@ -97,7 +96,7 @@ public:
std::vector<double> PredictMulticlassOneLine(const std::vector<std::pair<int, double>>& features) {
const int tid = PutFeatureValuesToBuffer(features);
// get result with sigmoid transform if needed
return boosting_->PredictMulticlass(features_[tid], num_used_model_);
return boosting_->PredictMulticlass(features_[tid]);
}
/*!
......@@ -224,8 +223,6 @@ private:
int num_threads_;
/*! \brief True if output leaf index instead of prediction score */
bool is_predict_leaf_index_;
/*! \brief Number of used model */
int num_used_model_;
};
} // namespace LightGBM
......
......@@ -19,7 +19,8 @@ namespace LightGBM {
GBDT::GBDT()
: train_score_updater_(nullptr),
gradients_(nullptr), hessians_(nullptr),
out_of_bag_data_indices_(nullptr), bag_data_indices_(nullptr) {
out_of_bag_data_indices_(nullptr), bag_data_indices_(nullptr),
saved_model_size_(-1), num_used_model_(0) {
}
GBDT::~GBDT() {
......@@ -43,6 +44,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
const std::vector<const Metric*>& training_metrics) {
gbdt_config_ = dynamic_cast<const GBDTConfig*>(config);
iter_ = 0;
saved_model_size_ = -1;
max_feature_idx_ = 0;
early_stopping_round_ = gbdt_config_->early_stopping_round;
train_data_ = train_data;
......@@ -438,6 +440,7 @@ void GBDT::ModelsFromString(const std::string& model_str) {
}
}
Log::Info("%d models has been loaded\n", models_.size());
num_used_model_ = static_cast<int>(models_.size()) / num_class_;
}
std::string GBDT::FeatureImportance() const {
......@@ -467,23 +470,17 @@ std::string GBDT::FeatureImportance() const {
return str_buf.str();
}
double GBDT::PredictRaw(const double* value, int num_used_model) const {
if (num_used_model < 0) {
num_used_model = static_cast<int>(models_.size());
}
double GBDT::PredictRaw(const double* value) const {
double ret = 0.0f;
for (int i = 0; i < num_used_model; ++i) {
for (int i = 0; i < num_used_model_; ++i) {
ret += models_[i]->Predict(value);
}
return ret;
}
double GBDT::Predict(const double* value, int num_used_model) const {
if (num_used_model < 0) {
num_used_model = static_cast<int>(models_.size());
}
double GBDT::Predict(const double* value) const {
double ret = 0.0f;
for (int i = 0; i < num_used_model; ++i) {
for (int i = 0; i < num_used_model_; ++i) {
ret += models_[i]->Predict(value);
}
// if need sigmoid transform
......@@ -493,12 +490,9 @@ double GBDT::Predict(const double* value, int num_used_model) const {
return ret;
}
std::vector<double> GBDT::PredictMulticlass(const double* value, int num_used_model) const {
if (num_used_model < 0) {
num_used_model = static_cast<int>(models_.size()) / num_class_;
}
std::vector<double> GBDT::PredictMulticlass(const double* value) const {
std::vector<double> ret(num_class_, 0.0f);
for (int i = 0; i < num_used_model; ++i) {
for (int i = 0; i < num_used_model_; ++i) {
for (int j = 0; j < num_class_; ++j){
ret[j] += models_[i * num_class_ + j] -> Predict(value);
}
......@@ -506,12 +500,9 @@ std::vector<double> GBDT::PredictMulticlass(const double* value, int num_used_mo
return ret;
}
std::vector<int> GBDT::PredictLeafIndex(const double* value, int num_used_model) const {
if (num_used_model < 0) {
num_used_model = static_cast<int>(models_.size());
}
std::vector<int> GBDT::PredictLeafIndex(const double* value) const {
std::vector<int> ret;
for (int i = 0; i < num_used_model; ++i) {
for (int i = 0; i < num_used_model_; ++i) {
ret.push_back(models_[i]->PredictLeafIndex(value));
}
return ret;
......
......@@ -55,33 +55,30 @@ public:
/*!
* \brief Prediction for one record without sigmoid transformation
* \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Prediction result for this record
*/
double PredictRaw(const double* feature_values, int num_used_model) const override;
double PredictRaw(const double* feature_values) const override;
/*!
* \brief Prediction for one record with sigmoid transformation if enabled
* \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Prediction result for this record
*/
double Predict(const double* feature_values, int num_used_model) const override;
double Predict(const double* feature_values) const override;
/*!
* \brief Prediction for multiclass classification
* \param value Feature value on this record
* \return Prediction result, num_class numbers per line
*/
std::vector<double> PredictMulticlass(const double* value, int num_used_model) const override;
std::vector<double> PredictMulticlass(const double* value) const override;
/*!
* \brief Prediction for one record with leaf index
* \param value Feature value on this record
* \param num_used_model Number of used model
* \return Predicted leaf index for this record
*/
std::vector<int> PredictLeafIndex(const double* value, int num_used_model) const override;
std::vector<int> PredictLeafIndex(const double* value) const override;
/*!
* \brief Serialize models by string
......@@ -115,6 +112,16 @@ public:
* \return Number of classes
*/
inline int NumberOfClass() const override { return num_class_; }
/*!
* \brief Set number of used model for prediction
*/
inline void SetNumUsedModel(int num_used_model) {
if (num_used_model >= 0) {
num_used_model_ = static_cast<int>(num_used_model / num_class_);
}
}
/*!
* \brief Get Type name of this boosting object
......@@ -208,9 +215,11 @@ private:
/*! \brief Index of label column */
data_size_t label_idx_;
/*! \brief Saved number of models */
int saved_model_size_ = -1;
int saved_model_size_;
/*! \brief File to write models */
std::ofstream model_output_file_;
/*! \brief number of used model */
int num_used_model_;
};
} // namespace LightGBM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment